11#ifndef _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_
22#define _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_
33
4- #include <nbl/builtin/hlsl/cpp_compat.hlsl>
5- #include <nbl/builtin/hlsl/concepts.hlsl>
6- #include <nbl/builtin/hlsl/math/intutil.hlsl>
7- #include <nbl/builtin/hlsl/utility.hlsl>
4+ #include "nbl/builtin/hlsl/cpp_compat.hlsl"
5+ #include "nbl/builtin/hlsl/type_traits.hlsl"
6+ #include "nbl/builtin/hlsl/workgroup/basic.hlsl"
7+ #include "nbl/builtin/hlsl/functional.hlsl"
8+ #include "nbl/builtin/hlsl/concepts/impl/base.hlsl"
89
910namespace nbl
1011{
@@ -13,41 +14,174 @@ namespace hlsl
1314namespace bitonic_sort
1415{
1516
16- template<typename KeyType, typename ValueType, typename Comparator>
17- void compareExchangeWithPartner (
18- bool takeLarger,
19- NBL_REF_ARG (pair<KeyType, ValueType>) loPair,
20- NBL_CONST_REF_ARG (pair<KeyType, ValueType>) partnerLoPair,
21- NBL_REF_ARG (pair<KeyType, ValueType>) hiPair,
22- NBL_CONST_REF_ARG (pair<KeyType, ValueType>) partnerHiPair,
23- NBL_CONST_REF_ARG (Comparator) comp)
24- {
25- const bool loSelfSmaller = comp (loPair.first, partnerLoPair.first);
26- const bool takePartnerLo = takeLarger ? loSelfSmaller : !loSelfSmaller;
27- if (takePartnerLo)
28- loPair = partnerLoPair;
29-
30- const bool hiSelfSmaller = comp (hiPair.first, partnerHiPair.first);
31- const bool takePartnerHi = takeLarger ? hiSelfSmaller : !hiSelfSmaller;
32- if (takePartnerHi)
33- hiPair = partnerHiPair;
34- }
35-
36- template<typename KeyType, typename ValueType, typename Comparator>
37- void compareSwap (
38- bool ascending,
39- NBL_REF_ARG (pair<KeyType, ValueType>) loPair,
40- NBL_REF_ARG (pair<KeyType, ValueType>) hiPair,
41- NBL_CONST_REF_ARG (Comparator) comp)
17+ template<typename KeyType, typename ValueType, uint32_t SubgroupSizelog2,typename Comparator>
18+ struct bitonic_sort_config
4219{
43- const bool shouldSwap = comp (hiPair.first, loPair.first);
44- const bool doSwap = (shouldSwap == ascending);
20+ using key_t = KeyType;
21+ using value_t = ValueType;
22+ using comparator_t = Comparator;
23+ static const uint32_t SubgroupSize = SubgroupSizelog2;
24+ };
25+
26+ template<typename Config, class device_capabilities = void >
27+ struct bitonic_sort;
4528
46- if (doSwap)
47- swap (loPair, hiPair);
29+
30+ template<typename K, typename V>
31+ inline K get_key (NBL_CONST_REF_ARG (pair<K, V>) kv)
32+ {
33+ return kv.first;
4834}
35+
36+ template<typename K, typename V>
37+ inline V get_value (NBL_CONST_REF_ARG (pair<K, V>) kv)
38+ {
39+ return kv.second;
4940}
41+
42+ template<typename KeyType>
43+ struct WorkgroupType
44+ {
45+ using key_t = KeyType;
46+ using index_t = uint32_t;
47+
48+ key_t key;
49+ index_t workgroupRelativeIndex;
50+ };
51+
52+
53+ template<typename KeyType, uint32_t KeyBits, typename StorageType = uint32_t>
54+ struct SubgroupType
55+ {
56+ using key_t = KeyType;
57+ using storage_t = StorageType;
58+
59+ static const uint32_t IndexBits = sizeof (storage_t) * 8u - KeyBits;
60+ static const storage_t KeyMask = (storage_t (1u) << KeyBits) - 1u;
61+ static const storage_t IndexMask = ~KeyMask;
62+
63+ storage_t packed;
64+
65+ static inline SubgroupType create (key_t key, uint32_t subgroupIndex)
66+ {
67+ SubgroupType st;
68+ st.packed = (storage_t (key) & KeyMask) | (storage_t (subgroupIndex) << KeyBits);
69+ return st;
70+ }
71+
72+ inline key_t getKey () { return key_t (packed & KeyMask); }
73+ inline uint32_t getSubgroupIndex () { return packed >> KeyBits; }
74+
75+ inline WorkgroupType<key_t> toWorkgroupType (uint32_t subgroupID, uint32_t elementsPerSubgroup)
76+ {
77+ WorkgroupType<key_t> wg;
78+ wg.key = getKey ();
79+ wg.workgroupRelativeIndex = workgroup::SubgroupContiguousIndex ();
80+ return wg;
81+ }
82+ };
83+
84+ template<typename KeyType>
85+ using SubgroupType27 = SubgroupType<KeyType, 27u>;
86+
87+ template<typename K>
88+ inline K get_key (NBL_CONST_REF_ARG (WorkgroupType<K>) wt)
89+ {
90+ return wt.key;
5091}
92+
93+
94+ template<typename K, uint32_t KeyBits, typename StorageType>
95+ inline K get_key (NBL_CONST_REF_ARG (SubgroupType<K, KeyBits, StorageType>) st)
96+ {
97+ return st.getKey ();
5198}
5299
100+
101+ //template<typename KeyType, typename Comp>
102+ //inline void compareSwap(
103+ // bool ascending,
104+ // NBL_REF_ARG(WorkgroupType<KeyType>) a,
105+ // NBL_REF_ARG(WorkgroupType<KeyType>) b,
106+ // NBL_CONST_REF_ARG(Comp) comp)
107+ //{
108+ // const bool swap = comp(b.key, a.key) == ascending;
109+ // WorkgroupType<KeyType> tmp = a;
110+ // a = swap ? b : a;
111+ // b = swap ? tmp : b;
112+ //}
113+
114+
115+ //template<typename KeyType, uint32_t KeyBits, typename StorageType, typename Comp>
116+ //inline void compareSwap(
117+ // bool ascending,
118+ // NBL_REF_ARG(SubgroupType<KeyType, KeyBits, StorageType>) a,
119+ // NBL_REF_ARG(SubgroupType<KeyType, KeyBits, StorageType>) b,
120+ // NBL_CONST_REF_ARG(Comp) comp)
121+ //{
122+ // const bool swap = comp(b.getKey(), a.getKey()) == ascending;
123+ // SubgroupType<KeyType, KeyBits, StorageType> tmp = a;
124+ // a = swap ? b : a;
125+ // b = swap ? tmp : b;
126+ //}
127+
128+ template<typename KeyValue, uint32_t Log2N, typename Comparator>
129+ struct LocalPasses
130+ {
131+ static const uint32_t N = 1u << Log2N;
132+ void operator ()(bool ascending, KeyValue data[N], NBL_CONST_REF_ARG (Comparator) comp);
133+ };
134+
135+ template<typename KeyValue, typename Comparator>
136+ struct LocalPasses<KeyValue, 1 , Comparator>
137+ {
138+ void operator ()(bool ascending, KeyValue data[2 ], NBL_CONST_REF_ARG (Comparator) comp)
139+ {
140+ const bool swap = comp (get_key (data[1 ]), get_key (data[0 ])) == ascending;
141+
142+ KeyValue temp = data[0 ];
143+ data[0 ] = swap ? data[1 ] : data[0 ];
144+ data[1 ] = swap ? temp : data[1 ];
145+ }
146+ };
147+
148+ // Specialization for WorkgroupType with 2 elements
149+ template<typename KeyType, typename Comparator>
150+ struct LocalPasses<WorkgroupType<KeyType>, 1 , Comparator>
151+ {
152+ static const uint32_t N = 2 ;
153+
154+ void operator ()(bool ascending,
155+ WorkgroupType<KeyType> data[N],
156+ NBL_CONST_REF_ARG (Comparator) comp)
157+ {
158+ const bool swap = comp (get_key (data[1 ]), get_key (data[0 ])) == ascending;
159+ WorkgroupType<KeyType> tmp = data[0 ];
160+ data[0 ] = swap ? data[1 ] : data[0 ];
161+ data[1 ] = swap ? tmp : data[1 ];
162+ }
163+ };
164+
165+ // Specialization for SubgroupType with 2 elements
166+ template<typename KeyType, uint32_t KeyBits, typename StorageType, typename Comparator>
167+ struct LocalPasses<SubgroupType<KeyType, KeyBits, StorageType>, 1 , Comparator>
168+ {
169+ static const uint32_t N = 2 ;
170+
171+ void operator ()(bool ascending,
172+ SubgroupType<KeyType, KeyBits, StorageType> data[N],
173+ NBL_CONST_REF_ARG (Comparator) comp)
174+ {
175+ const bool swap = comp (get_key (data[1 ]), get_key (data[0 ])) == ascending;
176+ SubgroupType<KeyType, KeyBits, StorageType> tmp = data[0 ];
177+ data[0 ] = swap ? data[1 ] : data[0 ];
178+ data[1 ] = swap ? tmp : data[1 ];
179+ }
180+ };
181+
182+
183+ } // namespace bitonic_sort
184+ } // namespace hlsl
185+ } // namespace nbl
186+
53187#endif
0 commit comments