@@ -11,7 +11,8 @@ SPDX-License-Identifier: MIT
/* Radix-sort tuning constants. NOTE(review): presumably the digit width
   consumed per radix pass and the number of bits per char; their users
   are outside this view - confirm before relying on this description. */
constant uint RADIX_SORT_BITS_PER_PASS = 4;
constant uint RADIX_SORT_CHAR_BIT = 8;
1313
/* Default devicelib sub-group sort - bitonic sorting network.
   Two versions are provided: key-value and value-only. */
1516
1617uint __builtin_sub_group_sort_mirror (uint idx , uint base )
1718{
@@ -100,6 +101,87 @@ type OVERLOADABLE __builtin_sub_group_sort32(const type aa, const bool is_asc) \
100101 __builtin_sub_group_sort_rotate(slotID, 2), \
101102 __builtin_sub_group_sort_sel(slotID, 2), is_asc); \
102103 return gg; \
104+ } \
105+ \
/* Key-value compare-exchange step of the sub-group sorting network.           */ \
/* Each lane fetches the key and payload of the lane given by shuffleMask,     */ \
/* then keeps either the smaller or the larger key together with its payload.  */ \
/*   key0, val0  - in/out: this lane's key and its uint payload.               */ \
/*   shuffleMask - forwarded to sub_group_shuffle to pick the partner lane.    */ \
/*   selectMask  - non-zero: keep the max (ascending) or the min (descending); */ \
/*                 zero: keep the min (ascending) or the max (descending).     */ \
/*   is_asc      - true for ascending order, false for descending.             */ \
void OVERLOADABLE __builtin_sub_group_sort_compare_exchange_kv(                   \
    type *key0, uint *val0, const uint shuffleMask, const uint selectMask,        \
    const bool is_asc)                                                            \
{                                                                                 \
    const type self_key = *key0;                                                  \
    const uint self_val = *val0;                                                  \
    const type peer_key = sub_group_shuffle(self_key, shuffleMask);               \
    /* The payload is a uint whatever the key type is; holding it in a        */ \
    /* key-typed temporary would convert or truncate it.                      */ \
    const uint peer_val = sub_group_shuffle(self_val, shuffleMask);               \
    const type key_min = min(self_key, peer_key);                                 \
    const type key_max = max(self_key, peer_key);                                 \
    /* The '<' versus '<=' asymmetry is deliberate: for equal keys both       */ \
    /* expressions yield the partner's payload, so the two lanes swap         */ \
    /* payloads and no payload is duplicated or dropped.                      */ \
    const uint val_min = (self_key < peer_key) ? self_val : peer_val;             \
    const uint val_max = (self_key <= peer_key) ? peer_val : self_val;            \
    if (selectMask) {                                                             \
        *key0 = is_asc ? key_max : key_min;                                       \
        *val0 = is_asc ? val_max : val_min;                                       \
    } else {                                                                      \
        *key0 = is_asc ? key_min : key_max;                                       \
        *val0 = is_asc ? val_min : val_max;                                       \
    }                                                                             \
}                                                                                 \
/* Key-value sorting network over groups of 8 sub-group lanes.                 */ \
/* Sorts *key in place and moves *val along with its key.                      */ \
/*   is_asc - true for ascending order, false for descending.                  */ \
void OVERLOADABLE __builtin_sub_group_sort8_kv(                                   \
    type *key, uint *val, const bool is_asc)                                      \
{                                                                                 \
    const uint kv_lane = get_sub_group_local_id();                                \
    /* Partner lanes for each stage of the network. */                            \
    const uint kv_mirror2 = __builtin_sub_group_sort_mirror(kv_lane, 2);          \
    const uint kv_mirror4 = __builtin_sub_group_sort_mirror(kv_lane, 4);          \
    const uint kv_mirror8 = __builtin_sub_group_sort_mirror(kv_lane, 8);          \
    const uint kv_rotate4 = __builtin_sub_group_sort_rotate(kv_lane, 4);          \
    const uint kv_rotate2 = __builtin_sub_group_sort_rotate(kv_lane, 2);          \
    /* Selectors deciding which side of each exchange this lane keeps. */         \
    const uint kv_sel2 = __builtin_sub_group_sort_sel(kv_lane, 2);                \
    const uint kv_sel4 = __builtin_sub_group_sort_sel(kv_lane, 4);                \
    const uint kv_sel8 = __builtin_sub_group_sort_sel(kv_lane, 8);                \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_mirror2, kv_sel2, is_asc);                                   \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_mirror4, kv_sel4, is_asc);                                   \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_mirror2, kv_sel2, is_asc);                                   \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_mirror8, kv_sel8, is_asc);                                   \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_rotate4, kv_sel4, is_asc);                                   \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_rotate2, kv_sel2, is_asc);                                   \
}                                                                                 \
/* Key-value sorting network over groups of 16 sub-group lanes.                */ \
/* Runs the 8-lane key-value sort first, then four further compare-exchange    */ \
/* stages. *key is sorted in place and *val moves along with its key.          */ \
/*   is_asc - true for ascending order, false for descending.                  */ \
void OVERLOADABLE __builtin_sub_group_sort16_kv(                                  \
    type *key, uint *val, const bool is_asc)                                      \
{                                                                                 \
    const uint kv_lane = get_sub_group_local_id();                                \
    /* Partner lanes for the four stages that follow the 8-lane sort. */          \
    const uint kv_mirror16 = __builtin_sub_group_sort_mirror(kv_lane, 16);        \
    const uint kv_rotate8 = __builtin_sub_group_sort_rotate(kv_lane, 8);          \
    const uint kv_rotate4 = __builtin_sub_group_sort_rotate(kv_lane, 4);          \
    const uint kv_rotate2 = __builtin_sub_group_sort_rotate(kv_lane, 2);          \
    /* Selectors deciding which side of each exchange this lane keeps. */         \
    const uint kv_sel16 = __builtin_sub_group_sort_sel(kv_lane, 16);              \
    const uint kv_sel8 = __builtin_sub_group_sort_sel(kv_lane, 8);                \
    const uint kv_sel4 = __builtin_sub_group_sort_sel(kv_lane, 4);                \
    const uint kv_sel2 = __builtin_sub_group_sort_sel(kv_lane, 2);                \
    __builtin_sub_group_sort8_kv(key, val, is_asc);                               \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_mirror16, kv_sel16, is_asc);                                 \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_rotate8, kv_sel8, is_asc);                                   \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_rotate4, kv_sel4, is_asc);                                   \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_rotate2, kv_sel2, is_asc);                                   \
}                                                                                 \
/* Key-value sorting network over groups of 32 sub-group lanes.                */ \
/* Runs the 16-lane key-value sort first, then five further compare-exchange   */ \
/* stages. *key is sorted in place and *val moves along with its key.          */ \
/*   is_asc - true for ascending order, false for descending.                  */ \
void OVERLOADABLE __builtin_sub_group_sort32_kv(                                  \
    type *key, uint *val, const bool is_asc)                                      \
{                                                                                 \
    const uint kv_lane = get_sub_group_local_id();                                \
    /* Partner lanes for the five stages that follow the 16-lane sort. */         \
    const uint kv_mirror32 = __builtin_sub_group_sort_mirror(kv_lane, 32);        \
    const uint kv_rotate16 = __builtin_sub_group_sort_rotate(kv_lane, 16);        \
    const uint kv_rotate8 = __builtin_sub_group_sort_rotate(kv_lane, 8);          \
    const uint kv_rotate4 = __builtin_sub_group_sort_rotate(kv_lane, 4);          \
    const uint kv_rotate2 = __builtin_sub_group_sort_rotate(kv_lane, 2);          \
    /* Selectors deciding which side of each exchange this lane keeps. */         \
    const uint kv_sel32 = __builtin_sub_group_sort_sel(kv_lane, 32);              \
    const uint kv_sel16 = __builtin_sub_group_sort_sel(kv_lane, 16);              \
    const uint kv_sel8 = __builtin_sub_group_sort_sel(kv_lane, 8);                \
    const uint kv_sel4 = __builtin_sub_group_sort_sel(kv_lane, 4);                \
    const uint kv_sel2 = __builtin_sub_group_sort_sel(kv_lane, 2);                \
    __builtin_sub_group_sort16_kv(key, val, is_asc);                              \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_mirror32, kv_sel32, is_asc);                                 \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_rotate16, kv_sel16, is_asc);                                 \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_rotate8, kv_sel8, is_asc);                                   \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_rotate4, kv_sel4, is_asc);                                   \
    __builtin_sub_group_sort_compare_exchange_kv(                                 \
        key, val, kv_rotate2, kv_sel2, is_asc);                                   \
}
104186
105187
@@ -1450,9 +1532,43 @@ type __builtin_IB_sub_group_clustered_sort_##direction##_##type_abbr(
14501532 return sorted; \
14511533}
14521534
// Clustered sorted ordinal: for every lane, returns the position inside its
// cluster that the lane's value occupies once the cluster is sorted.
// Example: SIMD16 input:
//    0, 2,19, 4, 1, 5, 7, 9,  19, 7,18, 4,10, 5, 2, 3
// Result of sorted_ordinal_descend with cluster_size=8:
//    7, 5, 0, 4, 6, 3, 2, 1,   0, 3, 1, 5, 2, 4, 7, 6
//
// Macro parameters:
//   type      - key type of the generated builtin (also its return type)
//   type_abbr - suffix appended to the builtin name
//   direction - name fragment inserted into the builtin name
//   is_asc    - sort direction passed to the key-value sort
//
// Cluster sizes 8, 16 and 32 are handled. For any other cluster_size no sort
// is performed and the lane's own index modulo cluster_size is returned.
#define DEFN_CLUSTERED_SUB_GROUP_SORTED_ORDINAL(type, type_abbr, direction, is_asc)  \
type __builtin_IB_sub_group_clustered_sorted_ordinal_##direction##_##type_abbr(     \
    type value, uint cluster_size)                                                   \
{                                                                                    \
    const uint lane = get_sub_group_local_id();                                      \
    type k = value;                                                                  \
    /* Pass 1 sorts the keys, carrying each key's source lane id in v.           */ \
    /* Pass 2 sorts those lane ids in ascending order, carrying in result the    */ \
    /* lane where each key landed, so every lane gets back its sorted position.  */ \
    uint v = lane;                                                                   \
    uint result = lane;                                                              \
    switch (cluster_size)                                                            \
    {                                                                                \
    case 8:                                                                          \
        __builtin_sub_group_sort8_kv(&k, &v, is_asc);                                \
        __builtin_sub_group_sort8_kv(&v, &result, true);                             \
        break;                                                                       \
    case 16:                                                                         \
        __builtin_sub_group_sort16_kv(&k, &v, is_asc);                               \
        __builtin_sub_group_sort16_kv(&v, &result, true);                            \
        break;                                                                       \
    case 32:                                                                         \
        __builtin_sub_group_sort32_kv(&k, &v, is_asc);                               \
        __builtin_sub_group_sort32_kv(&v, &result, true);                            \
        break;                                                                       \
    default:                                                                         \
        break;                                                                       \
    }                                                                                \
    /* Reduce the sub-group lane index to an index relative to the cluster. */       \
    return result % cluster_size;                                                    \
}
1566+
/* Defines, for one key type, the ascending and descending variants of both
   the clustered key-only sort and the clustered sorted-ordinal builtins. */
#define DEFN_CLUSTERED_SUB_GROUP_SORT(type, type_abbr)                       \
    DEFN_CLUSTERED_SUB_GROUP_SORT_KEY_ONLY(type, type_abbr, ascend, true)    \
    DEFN_CLUSTERED_SUB_GROUP_SORT_KEY_ONLY(type, type_abbr, descend, false)  \
    DEFN_CLUSTERED_SUB_GROUP_SORTED_ORDINAL(type, type_abbr, ascend, true)   \
    DEFN_CLUSTERED_SUB_GROUP_SORTED_ORDINAL(type, type_abbr, descend, false)
14561572
/* Instantiate the clustered sort builtins for char (suffix i8)
   and short (suffix i16). */
DEFN_CLUSTERED_SUB_GROUP_SORT(char, i8)
DEFN_CLUSTERED_SUB_GROUP_SORT(short, i16)
0 commit comments