2
2
#define _NBL_BUILTIN_HLSL_WORKGROUP_SHUFFLE_INCLUDED_
3
3
4
4
#include "nbl/builtin/hlsl/memory_accessor.hlsl"
5
+ #include "nbl/builtin/hlsl/functional.hlsl"
5
6
6
7
// TODO: Add other shuffles
7
8
@@ -14,26 +15,87 @@ namespace hlsl
14
15
namespace workgroup
15
16
{
16
17
18
+ // ------------------------------------- Skeletons for implementing other Shuffles --------------------------------
19
+
17
20
template<typename SharedMemoryAdaptor, typename T>
18
- struct shuffleXor
21
+ struct Shuffle
22
+ {
23
+ static void __call (NBL_REF_ARG (T) value, uint32_t storeIdx, uint32_t loadIdx, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
24
+ {
25
+ // TODO: optimization (optional) where we shuffle in the shared memory available (using rounds)
26
+ sharedmemAdaptor.template set<T>(storeIdx, value);
27
+
28
+ // Wait until all writes are done before reading
29
+ sharedmemAdaptor.workgroupExecutionAndMemoryBarrier ();
30
+
31
+ sharedmemAdaptor.template get<T>(loadIdx, value);
32
+ }
33
+
34
+ // By default store to threadID in the workgroup
35
+ static void __call (NBL_REF_ARG (T) value, uint32_t loadIdx, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
36
+ {
37
+ __call (value, uint32_t (SubgroupContiguousIndex ()), loadIdx, sharedmemAdaptor);
38
+ }
39
+ };
40
+
41
+ template<class UnOp, typename SharedMemoryAdaptor, typename T>
42
+ struct ShuffleUnOp
43
+ {
44
+ static void __call (NBL_REF_ARG (T) value, uint32_t a, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
45
+ {
46
+ UnOp unop;
47
+ // TODO: optimization (optional) where we shuffle in the shared memory available (using rounds)
48
+ sharedmemAdaptor.template set<T>(a, value);
49
+
50
+ // Wait until all writes are done before reading
51
+ sharedmemAdaptor.workgroupExecutionAndMemoryBarrier ();
52
+
53
+ sharedmemAdaptor.template get<T>(unop (a), value);
54
+ }
55
+
56
+ // By default store to threadID's index and load from unop(threadID)
57
+ static void __call (NBL_REF_ARG (T) value, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
58
+ {
59
+ __call (value, uint32_t (SubgroupContiguousIndex ()), sharedmemAdaptor);
60
+ }
61
+ };
62
+
63
+ template<class BinOp, typename SharedMemoryAdaptor, typename T>
64
+ struct ShuffleBinOp
19
65
{
20
- static void __call (NBL_REF_ARG (T) value, uint32_t mask , uint32_t threadID , NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
66
+ static void __call (NBL_REF_ARG (T) value, uint32_t a , uint32_t b , NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
21
67
{
68
+ BinOp binop;
22
69
// TODO: optimization (optional) where we shuffle in the shared memory available (using rounds)
23
- sharedmemAdaptor.template set<T>(threadID , value);
24
-
70
+ sharedmemAdaptor.template set<T>(a , value);
71
+
25
72
// Wait until all writes are done before reading
26
73
sharedmemAdaptor.workgroupExecutionAndMemoryBarrier ();
27
-
28
- sharedmemAdaptor.template get<T>(threadID ^ mask , value);
74
+
75
+ sharedmemAdaptor.template get<T>(binop (a, b) , value);
29
76
}
30
77
31
- static void __call (NBL_REF_ARG (T) value, uint32_t mask, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
78
+ // By default first argument of binary op is the thread's ID in the workgroup
79
+ static void __call (NBL_REF_ARG (T) value, uint32_t b, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
32
80
{
33
- __call (value, mask, uint32_t (SubgroupContiguousIndex ()), sharedmemAdaptor);
81
+ __call (value, uint32_t (SubgroupContiguousIndex ()), b , sharedmemAdaptor);
34
82
}
35
83
};
36
84
85
+ // ------------------------------------------ ShuffleXor ---------------------------------------------------------------
86
+
87
+ template<typename SharedMemoryAdaptor, typename T>
88
+ void shuffleXor (NBL_REF_ARG (T) value, uint32_t threadID, uint32_t mask, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
89
+ {
90
+ return ShuffleBinOp<bit_xor<uint32_t>, SharedMemoryAdaptor, T>::__call (value, threadID, mask, sharedmemAdaptor);
91
+ }
92
+
93
+ template<typename SharedMemoryAdaptor, typename T>
94
+ void shuffleXor (NBL_REF_ARG (T) value, uint32_t mask, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
95
+ {
96
+ return ShuffleBinOp<bit_xor<uint32_t>, SharedMemoryAdaptor, T>::__call (value, mask, sharedmemAdaptor);
97
+ }
98
+
37
99
}
38
100
}
39
101
}
0 commit comments