Skip to content

Commit 6f2ee43

Browse files
committed
update common.hlsl
1 parent 9d56528 commit 6f2ee43

File tree

1 file changed

+10
-142
lines changed

1 file changed

+10
-142
lines changed

include/nbl/builtin/hlsl/bitonic_sort/common.hlsl

Lines changed: 10 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22
#define _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_
33

44
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
5-
#include "nbl/builtin/hlsl/type_traits.hlsl"
6-
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
75
#include "nbl/builtin/hlsl/functional.hlsl"
8-
#include "nbl/builtin/hlsl/concepts/impl/base.hlsl"
96

107
namespace nbl
118
{
@@ -14,168 +11,39 @@ namespace hlsl
1411
namespace bitonic_sort
1512
{
1613

17-
template<typename KeyType, typename ValueType, uint32_t SubgroupSizelog2,typename Comparator>
14+
template<typename KeyType, typename ValueType, uint32_t SubgroupSizelog2, typename Comparator>
1815
struct bitonic_sort_config
1916
{
2017
using key_t = KeyType;
2118
using value_t = ValueType;
2219
using comparator_t = Comparator;
23-
static const uint32_t SubgroupSize = SubgroupSizelog2;
20+
static const uint32_t SubgroupSizeLog2 = SubgroupSizelog2;
21+
static const uint32_t SubgroupSize = 1u << SubgroupSizeLog2;
2422
};
2523

2624
template<typename Config, class device_capabilities = void>
2725
struct bitonic_sort;
2826

2927

30-
template<typename K, typename V>
31-
inline K get_key(NBL_CONST_REF_ARG(pair<K, V>) kv)
32-
{
33-
return kv.first;
34-
}
35-
36-
template<typename K, typename V>
37-
inline V get_value(NBL_CONST_REF_ARG(pair<K, V>) kv)
38-
{
39-
return kv.second;
40-
}
41-
42-
template<typename KeyType>
43-
struct WorkgroupType
44-
{
45-
using key_t = KeyType;
46-
using index_t = uint32_t;
47-
48-
key_t key;
49-
index_t workgroupRelativeIndex;
50-
};
51-
52-
53-
template<typename KeyType, uint32_t KeyBits, typename StorageType = uint32_t>
54-
struct SubgroupType
55-
{
56-
using key_t = KeyType;
57-
using storage_t = StorageType;
58-
59-
static const uint32_t IndexBits = sizeof(storage_t) * 8u - KeyBits;
60-
static const storage_t KeyMask = (storage_t(1u) << KeyBits) - 1u;
61-
static const storage_t IndexMask = ~KeyMask;
62-
63-
storage_t packed;
64-
65-
static inline SubgroupType create(key_t key, uint32_t subgroupIndex)
66-
{
67-
SubgroupType st;
68-
st.packed = (storage_t(key) & KeyMask) | (storage_t(subgroupIndex) << KeyBits);
69-
return st;
70-
}
71-
72-
inline key_t getKey() { return key_t(packed & KeyMask); }
73-
inline uint32_t getSubgroupIndex() { return packed >> KeyBits; }
74-
75-
inline WorkgroupType<key_t> toWorkgroupType(uint32_t subgroupID, uint32_t elementsPerSubgroup)
76-
{
77-
WorkgroupType<key_t> wg;
78-
wg.key = getKey();
79-
wg.workgroupRelativeIndex = workgroup::SubgroupContiguousIndex();
80-
return wg;
81-
}
82-
};
83-
84-
template<typename KeyType>
85-
using SubgroupType27 = SubgroupType<KeyType, 27u>;
86-
87-
template<typename K>
88-
inline K get_key(NBL_CONST_REF_ARG(WorkgroupType<K>) wt)
89-
{
90-
return wt.key;
91-
}
92-
93-
94-
template<typename K, uint32_t KeyBits, typename StorageType>
95-
inline K get_key(NBL_CONST_REF_ARG(SubgroupType<K, KeyBits, StorageType>) st)
96-
{
97-
return st.getKey();
98-
}
99-
100-
101-
//template<typename KeyType, typename Comp>
102-
//inline void compareSwap(
103-
// bool ascending,
104-
// NBL_REF_ARG(WorkgroupType<KeyType>) a,
105-
// NBL_REF_ARG(WorkgroupType<KeyType>) b,
106-
// NBL_CONST_REF_ARG(Comp) comp)
107-
//{
108-
// const bool swap = comp(b.key, a.key) == ascending;
109-
// WorkgroupType<KeyType> tmp = a;
110-
// a = swap ? b : a;
111-
// b = swap ? tmp : b;
112-
//}
113-
114-
115-
//template<typename KeyType, uint32_t KeyBits, typename StorageType, typename Comp>
116-
//inline void compareSwap(
117-
// bool ascending,
118-
// NBL_REF_ARG(SubgroupType<KeyType, KeyBits, StorageType>) a,
119-
// NBL_REF_ARG(SubgroupType<KeyType, KeyBits, StorageType>) b,
120-
// NBL_CONST_REF_ARG(Comp) comp)
121-
//{
122-
// const bool swap = comp(b.getKey(), a.getKey()) == ascending;
123-
// SubgroupType<KeyType, KeyBits, StorageType> tmp = a;
124-
// a = swap ? b : a;
125-
// b = swap ? tmp : b;
126-
//}
127-
12828
template<typename KeyValue, uint32_t Log2N, typename Comparator>
12929
struct LocalPasses
13030
{
131-
static const uint32_t N = 1u << Log2N;
132-
void operator()(bool ascending, KeyValue data[N], NBL_CONST_REF_ARG(Comparator) comp);
31+
static const uint32_t N = 1u << Log2N;
32+
void operator()(bool ascending, KeyValue data[N], NBL_CONST_REF_ARG(Comparator) comp);
13333
};
13434

13535
template<typename KeyValue, typename Comparator>
13636
struct LocalPasses<KeyValue, 1, Comparator>
137-
{
138-
void operator()(bool ascending, KeyValue data[2], NBL_CONST_REF_ARG(Comparator) comp)
139-
{
140-
const bool swap = comp(get_key(data[1]), get_key(data[0])) == ascending;
141-
142-
KeyValue temp = data[0];
143-
data[0] = swap ? data[1] : data[0];
144-
data[1] = swap ? temp : data[1];
145-
}
146-
};
147-
148-
// Specialization for WorkgroupType with 2 elements
149-
template<typename KeyType, typename Comparator>
150-
struct LocalPasses<WorkgroupType<KeyType>, 1, Comparator>
15137
{
15238
static const uint32_t N = 2;
15339

154-
void operator()(bool ascending,
155-
WorkgroupType<KeyType> data[N],
156-
NBL_CONST_REF_ARG(Comparator) comp)
40+
void operator()(bool ascending, KeyValue data[N], NBL_CONST_REF_ARG(Comparator) comp)
15741
{
158-
const bool swap = comp(get_key(data[1]), get_key(data[0])) == ascending;
159-
WorkgroupType<KeyType> tmp = data[0];
160-
data[0] = swap ? data[1] : data[0];
161-
data[1] = swap ? tmp : data[1];
162-
}
163-
};
42+
const bool shouldSwap = comp(data[1], data[0]) == ascending;
16443

165-
// Specialization for SubgroupType with 2 elements
166-
template<typename KeyType, uint32_t KeyBits, typename StorageType, typename Comparator>
167-
struct LocalPasses<SubgroupType<KeyType, KeyBits, StorageType>, 1, Comparator>
168-
{
169-
static const uint32_t N = 2;
170-
171-
void operator()(bool ascending,
172-
SubgroupType<KeyType, KeyBits, StorageType> data[N],
173-
NBL_CONST_REF_ARG(Comparator) comp)
174-
{
175-
const bool swap = comp(get_key(data[1]), get_key(data[0])) == ascending;
176-
SubgroupType<KeyType, KeyBits, StorageType> tmp = data[0];
177-
data[0] = swap ? data[1] : data[0];
178-
data[1] = swap ? tmp : data[1];
44+
KeyValue temp = data[0];
45+
data[0] = shouldSwap ? data[1] : data[0];
46+
data[1] = shouldSwap ? temp : data[1];
17947
}
18048
};
18149

0 commit comments

Comments
 (0)