22#define _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_
33
44#include "nbl/builtin/hlsl/cpp_compat.hlsl"
5- #include "nbl/builtin/hlsl/type_traits.hlsl"
6- #include "nbl/builtin/hlsl/workgroup/basic.hlsl"
75#include "nbl/builtin/hlsl/functional.hlsl"
8- #include "nbl/builtin/hlsl/concepts/impl/base.hlsl"
96
107namespace nbl
118{
@@ -14,168 +11,39 @@ namespace hlsl
1411namespace bitonic_sort
1512{
1613
// Compile-time configuration bundle for the bitonic sort.
//   KeyType          - exposed as key_t
//   ValueType        - exposed as value_t
//   SubgroupSizelog2 - base-2 logarithm of the subgroup size
//   Comparator       - exposed as comparator_t
template<typename KeyType, typename ValueType, uint32_t SubgroupSizelog2, typename Comparator>
struct bitonic_sort_config
{
    using key_t = KeyType;
    using value_t = ValueType;
    using comparator_t = Comparator;

    // The exponent exactly as supplied by the template argument.
    static const uint32_t SubgroupSizeLog2 = SubgroupSizelog2;
    // Element count derived from the exponent; must not be the exponent itself.
    static const uint32_t SubgroupSize = 1u << SubgroupSizeLog2;
};
2523
// Forward declaration of the sorter; only the declaration is visible here.
//   Config              - configuration type (presumably a bitonic_sort_config instantiation; confirm at the definition)
//   device_capabilities - optional capability traits type, defaults to void
template<typename Config, class device_capabilities = void>
struct bitonic_sort;
2826
2927
30- template<typename K, typename V>
31- inline K get_key (NBL_CONST_REF_ARG (pair<K, V>) kv)
32- {
33- return kv.first;
34- }
35-
36- template<typename K, typename V>
37- inline V get_value (NBL_CONST_REF_ARG (pair<K, V>) kv)
38- {
39- return kv.second;
40- }
41-
42- template<typename KeyType>
43- struct WorkgroupType
44- {
45- using key_t = KeyType;
46- using index_t = uint32_t;
47-
48- key_t key;
49- index_t workgroupRelativeIndex;
50- };
51-
52-
53- template<typename KeyType, uint32_t KeyBits, typename StorageType = uint32_t>
54- struct SubgroupType
55- {
56- using key_t = KeyType;
57- using storage_t = StorageType;
58-
59- static const uint32_t IndexBits = sizeof (storage_t) * 8u - KeyBits;
60- static const storage_t KeyMask = (storage_t (1u) << KeyBits) - 1u;
61- static const storage_t IndexMask = ~KeyMask;
62-
63- storage_t packed;
64-
65- static inline SubgroupType create (key_t key, uint32_t subgroupIndex)
66- {
67- SubgroupType st;
68- st.packed = (storage_t (key) & KeyMask) | (storage_t (subgroupIndex) << KeyBits);
69- return st;
70- }
71-
72- inline key_t getKey () { return key_t (packed & KeyMask); }
73- inline uint32_t getSubgroupIndex () { return packed >> KeyBits; }
74-
75- inline WorkgroupType<key_t> toWorkgroupType (uint32_t subgroupID, uint32_t elementsPerSubgroup)
76- {
77- WorkgroupType<key_t> wg;
78- wg.key = getKey ();
79- wg.workgroupRelativeIndex = workgroup::SubgroupContiguousIndex ();
80- return wg;
81- }
82- };
83-
84- template<typename KeyType>
85- using SubgroupType27 = SubgroupType<KeyType, 27u>;
86-
87- template<typename K>
88- inline K get_key (NBL_CONST_REF_ARG (WorkgroupType<K>) wt)
89- {
90- return wt.key;
91- }
92-
93-
94- template<typename K, uint32_t KeyBits, typename StorageType>
95- inline K get_key (NBL_CONST_REF_ARG (SubgroupType<K, KeyBits, StorageType>) st)
96- {
97- return st.getKey ();
98- }
99-
100-
101- //template<typename KeyType, typename Comp>
102- //inline void compareSwap(
103- // bool ascending,
104- // NBL_REF_ARG(WorkgroupType<KeyType>) a,
105- // NBL_REF_ARG(WorkgroupType<KeyType>) b,
106- // NBL_CONST_REF_ARG(Comp) comp)
107- //{
108- // const bool swap = comp(b.key, a.key) == ascending;
109- // WorkgroupType<KeyType> tmp = a;
110- // a = swap ? b : a;
111- // b = swap ? tmp : b;
112- //}
113-
114-
115- //template<typename KeyType, uint32_t KeyBits, typename StorageType, typename Comp>
116- //inline void compareSwap(
117- // bool ascending,
118- // NBL_REF_ARG(SubgroupType<KeyType, KeyBits, StorageType>) a,
119- // NBL_REF_ARG(SubgroupType<KeyType, KeyBits, StorageType>) b,
120- // NBL_CONST_REF_ARG(Comp) comp)
121- //{
122- // const bool swap = comp(b.getKey(), a.getKey()) == ascending;
123- // SubgroupType<KeyType, KeyBits, StorageType> tmp = a;
124- // a = swap ? b : a;
125- // b = swap ? tmp : b;
126- //}
127-
12828template<typename KeyValue, uint32_t Log2N, typename Comparator>
12929struct LocalPasses
13030{
131- static const uint32_t N = 1u << Log2N;
132- void operator ()(bool ascending, KeyValue data[N], NBL_CONST_REF_ARG (Comparator) comp);
31+ static const uint32_t N = 1u << Log2N;
32+ void operator ()(bool ascending, KeyValue data[N], NBL_CONST_REF_ARG (Comparator) comp);
13333};
13434
13535template<typename KeyValue, typename Comparator>
13636struct LocalPasses<KeyValue, 1 , Comparator>
137- {
138- void operator ()(bool ascending, KeyValue data[2 ], NBL_CONST_REF_ARG (Comparator) comp)
139- {
140- const bool swap = comp (get_key (data[1 ]), get_key (data[0 ])) == ascending;
141-
142- KeyValue temp = data[0 ];
143- data[0 ] = swap ? data[1 ] : data[0 ];
144- data[1 ] = swap ? temp : data[1 ];
145- }
146- };
147-
148- // Specialization for WorkgroupType with 2 elements
149- template<typename KeyType, typename Comparator>
150- struct LocalPasses<WorkgroupType<KeyType>, 1 , Comparator>
15137{
15238 static const uint32_t N = 2 ;
15339
154- void operator ()(bool ascending,
155- WorkgroupType<KeyType> data[N],
156- NBL_CONST_REF_ARG (Comparator) comp)
40+ void operator ()(bool ascending, KeyValue data[N], NBL_CONST_REF_ARG (Comparator) comp)
15741 {
158- const bool swap = comp (get_key (data[1 ]), get_key (data[0 ])) == ascending;
159- WorkgroupType<KeyType> tmp = data[0 ];
160- data[0 ] = swap ? data[1 ] : data[0 ];
161- data[1 ] = swap ? tmp : data[1 ];
162- }
163- };
42+ const bool shouldSwap = comp (data[1 ], data[0 ]) == ascending;
16443
165- // Specialization for SubgroupType with 2 elements
166- template<typename KeyType, uint32_t KeyBits, typename StorageType, typename Comparator>
167- struct LocalPasses<SubgroupType<KeyType, KeyBits, StorageType>, 1 , Comparator>
168- {
169- static const uint32_t N = 2 ;
170-
171- void operator ()(bool ascending,
172- SubgroupType<KeyType, KeyBits, StorageType> data[N],
173- NBL_CONST_REF_ARG (Comparator) comp)
174- {
175- const bool swap = comp (get_key (data[1 ]), get_key (data[0 ])) == ascending;
176- SubgroupType<KeyType, KeyBits, StorageType> tmp = data[0 ];
177- data[0 ] = swap ? data[1 ] : data[0 ];
178- data[1 ] = swap ? tmp : data[1 ];
44+ KeyValue temp = data[0 ];
45+ data[0 ] = shouldSwap ? data[1 ] : data[0 ];
46+ data[1 ] = shouldSwap ? temp : data[1 ];
17947 }
18048};
18149
0 commit comments