#include "nbl/builtin/hlsl/workgroup/shuffle.hlsl"
#include "nbl/builtin/hlsl/mpl.hlsl"
#include "nbl/builtin/hlsl/memory_accessor.hlsl"
+#include "nbl/builtin/hlsl/bit.hlsl"
+
+// Caveats
+// - Sin and Cos in HLSL take 32-bit floats. Using this library with 64-bit floats works perfectly fine, but DXC will emit warnings

namespace nbl
{
@@ -18,20 +22,77 @@ namespace fft
{

// ---------------------------------- Utils -----------------------------------------------
+template<typename SharedMemoryAdaptor, typename Scalar>
+struct exchangeValues;

-template<typename SharedMemoryAccessor, typename Scalar>
-void exchangeValues(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
+template<typename SharedMemoryAdaptor>
+struct exchangeValues<SharedMemoryAdaptor, float16_t>
{
-    const bool topHalf = bool(threadID & stride);
-    // Ternary won't take structs so we use this aux variable
-    vector<Scalar, 2> toExchange = topHalf ? vector<Scalar, 2>(lo.real(), lo.imag()) : vector<Scalar, 2>(hi.real(), hi.imag());
-    complex_t<Scalar> toExchangeComplex = {toExchange.x, toExchange.y};
-    shuffleXor<SharedMemoryAccessor, complex_t<Scalar> >::__call(toExchangeComplex, stride, sharedmemAccessor);
-    if (topHalf)
-        lo = toExchangeComplex;
-    else
-        hi = toExchangeComplex;
-}
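+    // Each thread trades one of its two elements with its partner at (threadID ^ stride) through shared memory: the "top" thread (threadID & stride set) gives away lo, the other gives away hi.
+    // The complex value is packed into raw 32-bit words via bit_cast before the shuffle (one uint32_t for float16_t, uint32_t2 for float32_t, uint32_t4 for float64_t) and unpacked afterwards.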
+    static void __call(NBL_REF_ARG(complex_t<float16_t>) lo, NBL_REF_ARG(complex_t<float16_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+    {
+        const bool topHalf = bool(threadID & stride);
+        // Ternary won't take structs so we use this aux variable
+        uint32_t toExchange = bit_cast<uint32_t, vector<float16_t, 2> >(topHalf ? vector<float16_t, 2>(lo.real(), lo.imag()) : vector<float16_t, 2>(hi.real(), hi.imag()));
+        shuffleXor<SharedMemoryAdaptor, uint32_t>::__call(toExchange, stride, sharedmemAdaptor);
+        vector<float16_t, 2> exchanged = bit_cast<vector<float16_t, 2>, uint32_t>(toExchange);
+        if (topHalf)
+        {
+            lo.real(exchanged.x);
+            lo.imag(exchanged.y);
+        }
+        else
+        {
+            hi.real(exchanged.x);
+            hi.imag(exchanged.y);
+        }
+    }
+};
+
+template<typename SharedMemoryAdaptor>
+struct exchangeValues<SharedMemoryAdaptor, float32_t>
+{
+    static void __call(NBL_REF_ARG(complex_t<float32_t>) lo, NBL_REF_ARG(complex_t<float32_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+    {
+        const bool topHalf = bool(threadID & stride);
+        // Ternary won't take structs so we use this aux variable
+        vector<uint32_t, 2> toExchange = bit_cast<vector<uint32_t, 2>, vector<float32_t, 2> >(topHalf ? vector<float32_t, 2>(lo.real(), lo.imag()) : vector<float32_t, 2>(hi.real(), hi.imag()));
+        shuffleXor<SharedMemoryAdaptor, vector<uint32_t, 2> >::__call(toExchange, stride, sharedmemAdaptor);
+        vector<float32_t, 2> exchanged = bit_cast<vector<float32_t, 2>, vector<uint32_t, 2> >(toExchange);
+        if (topHalf)
+        {
+            lo.real(exchanged.x);
+            lo.imag(exchanged.y);
+        }
+        else
+        {
+            hi.real(exchanged.x);
+            hi.imag(exchanged.y);
+        }
+    }
+};
+
+template<typename SharedMemoryAdaptor>
+struct exchangeValues<SharedMemoryAdaptor, float64_t>
+{
+    static void __call(NBL_REF_ARG(complex_t<float64_t>) lo, NBL_REF_ARG(complex_t<float64_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+    {
+        const bool topHalf = bool(threadID & stride);
+        // Ternary won't take structs so we use this aux variable
+        vector<uint32_t, 4> toExchange = bit_cast<vector<uint32_t, 4>, vector<float64_t, 2> >(topHalf ? vector<float64_t, 2>(lo.real(), lo.imag()) : vector<float64_t, 2>(hi.real(), hi.imag()));
+        shuffleXor<SharedMemoryAdaptor, vector<uint32_t, 4> >::__call(toExchange, stride, sharedmemAdaptor);
+        vector<float64_t, 2> exchanged = bit_cast<vector<float64_t, 2>, vector<uint32_t, 4> >(toExchange);
+        if (topHalf)
+        {
+            lo.real(exchanged.x);
+            lo.imag(exchanged.y);
+        }
+        else
+        {
+            hi.real(exchanged.x);
+            hi.imag(exchanged.y);
+        }
+    }
+};

} //namespace fft

@@ -51,10 +112,10 @@ struct FFT;
template<typename Scalar, class device_capabilities>
struct FFT<2, false, Scalar, device_capabilities>
{
-    template<typename SharedMemoryAccessor>
-    static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
+    template<typename SharedMemoryAdaptor>
+    static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
    {
-        fft::exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAccessor);
+        fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);

        // Get twiddle with k = threadID mod stride, halfN = stride
        hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID & (stride - 1), stride), lo, hi);
@@ -77,18 +138,25 @@ struct FFT<2,false, Scalar, device_capabilities>
// If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
{
+    // Set up the memory adaptor
+    MemoryAdaptor<SharedMemoryAccessor> sharedmemAdaptor;
+    sharedmemAdaptor.accessor = sharedmemAccessor;
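+    // The adaptor keeps its own copy of the accessor, which is why its state is written back to sharedmemAccessor at the end of this block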
+
    // special first iteration
    hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);

    // Run bigger steps until Subgroup-sized
    for (uint32_t stride = _NBL_HLSL_WORKGROUP_SIZE_ >> 1; stride > glsl::gl_SubgroupSize(); stride >>= 1)
    {
-        FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAccessor);
-        sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
+        FFT_loop< MemoryAdaptor<SharedMemoryAccessor> >(stride, lo, hi, threadID, sharedmemAdaptor);
+        sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
    }

    // special last workgroup-shuffle
-    fft::exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAccessor);
+    fft::exchangeValues<MemoryAdaptor<SharedMemoryAccessor>, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
+
+    // Remember to update the accessor's state
+    sharedmemAccessor = sharedmemAdaptor.accessor;
}

// Subgroup-sized FFT
@@ -106,13 +174,13 @@ struct FFT<2,false, Scalar, device_capabilities>
template<typename Scalar, class device_capabilities>
struct FFT<2, true, Scalar, device_capabilities>
{
-    template<typename SharedMemoryAccessor>
-    static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
+    template<typename SharedMemoryAdaptor>
+    static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
    {
        // Get twiddle with k = threadID mod stride, halfN = stride
        hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID & (stride - 1), stride), lo, hi);

-        fft::exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAccessor);
+        fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
    }

@@ -135,22 +203,29 @@ struct FFT<2,true, Scalar, device_capabilities>
// If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
{
+    // Set up the memory adaptor
+    MemoryAdaptor<SharedMemoryAccessor> sharedmemAdaptor;
+    sharedmemAdaptor.accessor = sharedmemAccessor;
+
    // special first workgroup-shuffle
-    fft::exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAccessor);
+    fft::exchangeValues<MemoryAdaptor<SharedMemoryAccessor>, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);

    // The bigger steps
    for (uint32_t stride = glsl::gl_SubgroupSize() << 1; stride < _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
    {
        // Order of waiting for shared mem writes is also reversed here, since the shuffle came earlier
-        sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
-        FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAccessor);
+        sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
+        FFT_loop< MemoryAdaptor<SharedMemoryAccessor> >(stride, lo, hi, threadID, sharedmemAdaptor);
    }

    // special last iteration
    hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);
    divides_assign< complex_t<Scalar> > divAss;
-    divAss(lo, _NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize());
-    divAss(hi, _NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize());
+    divAss(lo, Scalar(_NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize()));
+    divAss(hi, Scalar(_NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize()));
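+    // the divisor is now constructed explicitly as a Scalar instead of relying on an implicit uint32_t-to-Scalar conversion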
+
+    // Remember to update the accessor's state
+    sharedmemAccessor = sharedmemAdaptor.accessor;
}

// Put values back in global mem
@@ -166,10 +241,10 @@ struct FFT<K, false, Scalar, device_capabilities>
template<typename Accessor, typename SharedMemoryAccessor>
static enable_if_t< (mpl::is_pot_v<K> && K > 2), void > __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
{
-    for (uint32_t stride = (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_; stride > _NBL_HLSL_WORKGROUP_SIZE_; stride >>= 1)
+    for (uint32_t stride = (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; stride > _NBL_HLSL_WORKGROUP_SIZE_; stride >>= 1)
    {
        //[unroll(K/2)]
-        for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
+        for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
        {
            const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
            const uint32_t hiIx = loIx | stride;
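            // loIx keeps the low log2(stride) bits of virtualThreadID and shifts the upper bits up by one, leaving a zero at the stride bit; hiIx addresses the butterfly partner by setting that bit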
@@ -223,7 +298,7 @@ struct FFT<K, true, Scalar, device_capabilities>
{
    accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
    //[unroll(K/2)]
-    for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
+    for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
    {
        const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
        const uint32_t hiIx = loIx | stride;
@@ -235,11 +310,11 @@ struct FFT<K, true, Scalar, device_capabilities>
        hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(virtualThreadID & (stride - 1), stride), lo, hi);

        // Divide by special factor at the end
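        // (only the largest-stride pass applies the division, so each element gets scaled exactly once)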
-        if ((K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_ == stride)
+        if ((K / 2) * _NBL_HLSL_WORKGROUP_SIZE_ == stride)
        {
            divides_assign< complex_t<Scalar> > divAss;
-            divAss(lo, K >> 1);
-            divAss(hi, K >> 1);
+            divAss(lo, K / 2);
+            divAss(hi, K / 2);
        }

        accessor.set(loIx, lo);