2
2
#define _NBL_BUILTIN_HLSL_WORKGROUP_FFT_INCLUDED_
3
3
4
4
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
5
+ #include <nbl/builtin/hlsl/concepts.hlsl>
5
6
#include <nbl/builtin/hlsl/fft/common.hlsl>
6
7
8
+ // ------------------------------- COMMON -----------------------------------------
9
+
10
+ namespace nbl
11
+ {
12
+ namespace hlsl
13
+ {
14
+ namespace workgroup
15
+ {
16
+ namespace fft
17
+ {
18
+
19
+ template<uint16_t _ElementsPerInvocationLog2, uint16_t _WorkgroupSizeLog2, typename _Scalar NBL_PRIMARY_REQUIRES (_ElementsPerInvocationLog2 > 0 && _WorkgroupSizeLog2 >= 5 )
20
+ struct ConstevalParameters
21
+ {
22
+ using scalar_t = _Scalar;
23
+
24
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = _ElementsPerInvocationLog2;
25
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
26
+ NBL_CONSTEXPR_STATIC_INLINE uint32_t TotalSize = uint32_t (1 ) << (ElementsPerInvocationLog2 + WorkgroupSizeLog2);
27
+
28
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = uint16_t (1 ) << ElementsPerInvocationLog2;
29
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t (1 ) << WorkgroupSizeLog2;
30
+
31
+ // Required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
32
+ NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = (sizeof (complex_t<scalar_t>) / sizeof (uint32_t)) << WorkgroupSizeLog2;
33
+ };
34
+
35
+ }
36
+ }
37
+ }
38
+ }
39
+ // ------------------------------- END COMMON -----------------------------------------
40
+
41
+ // ------------------------------- CPP ONLY -------------------------------------------
7
42
#ifndef __HLSL_VERSION
8
43
#include <nbl/video/IPhysicalDevice.h>
9
44
@@ -30,6 +65,9 @@ inline std::pair<uint16_t, uint16_t> optimalFFTParameters(const video::ILogicalD
30
65
}
31
66
}
32
67
68
+ // ------------------------------- END CPP ONLY -------------------------------------------
69
+
70
+ // ------------------------------- HLSL ONLY ----------------------------------------------
33
71
#else
34
72
35
73
#include "nbl/builtin/hlsl/subgroup/fft.hlsl"
@@ -39,7 +77,6 @@ inline std::pair<uint16_t, uint16_t> optimalFFTParameters(const video::ILogicalD
39
77
#include "nbl/builtin/hlsl/mpl.hlsl"
40
78
#include "nbl/builtin/hlsl/memory_accessor.hlsl"
41
79
#include "nbl/builtin/hlsl/bit.hlsl"
42
- #include "nbl/builtin/hlsl/concepts.hlsl"
43
80
44
81
// Caveats
45
82
// - Sin and Cos in HLSL take 32-bit floats. Using this library with 64-bit floats works perfectly fine, but DXC will emit warnings
@@ -157,19 +194,35 @@ struct FFTIndexingUtils
157
194
// but also the thread holding said mirror value will at the same time be trying to unpack `NFFT[someOtherIndex]` and need the mirror value of that.
158
195
// As long as this unpacking is happening concurrently and in order (meaning the local element index - the higher bits - of `globalElementIndex` and `someOtherIndex` is the
159
196
// same) then this function returns both the SubgroupContiguousIndex of the other thread AND the local element index of *the mirror* of `someOtherIndex`
160
- struct NablaMirrorTradeInfo
197
+ struct NablaMirrorLocalInfo
161
198
{
162
199
uint32_t otherThreadID;
163
200
uint32_t mirrorLocalIndex;
164
201
};
165
202
166
- static NablaMirrorTradeInfo getNablaMirrorTradeInfo (uint32_t localElementIndex)
203
+ static NablaMirrorLocalInfo getNablaMirrorLocalInfo (uint32_t localElementIndex)
167
204
{
168
205
const uint32_t globalElementIndex = localElementIndex * WorkgroupSize | workgroup::SubgroupContiguousIndex ();
169
206
const uint32_t otherElementIndex = FFTIndexingUtils::getNablaMirrorIndex (globalElementIndex);
170
207
const uint32_t mirrorLocalIndex = otherElementIndex / WorkgroupSize;
171
208
const uint32_t otherThreadID = otherElementIndex & (WorkgroupSize - 1 );
172
- NablaMirrorTradeInfo info = { otherThreadID, mirrorLocalIndex };
209
+ const NablaMirrorLocalInfo info = { otherThreadID, mirrorLocalIndex };
210
+ return info;
211
+ }
212
+
213
+ // Like the above, but return global indices instead.
214
+ struct NablaMirrorGlobalInfo
215
+ {
216
+ uint32_t otherThreadID;
217
+ uint32_t mirrorGlobalIndex;
218
+ };
219
+
220
+ static NablaMirrorGlobalInfo getNablaMirrorGlobalInfo (uint32_t globalElementIndex)
221
+ {
222
+ const uint32_t otherElementIndex = FFTIndexingUtils::getNablaMirrorIndex (globalElementIndex);
223
+ const uint32_t mirrorGlobalIndex = glsl::bitfieldInsert<uint32_t>(otherElementIndex, workgroup::SubgroupContiguousIndex (), 0 , uint32_t (WorkgroupSizeLog2));
224
+ const uint32_t otherThreadID = otherElementIndex & (WorkgroupSize - 1 );
225
+ const NablaMirrorGlobalInfo info = { otherThreadID, mirrorGlobalIndex };
173
226
return info;
174
227
}
175
228
@@ -178,31 +231,39 @@ struct FFTIndexingUtils
178
231
NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkgroupSize = uint32_t (1 ) << WorkgroupSizeLog2;
179
232
};
180
233
181
- } //namespace fft
182
-
183
- // ----------------------------------- End Utils --------------------------------------------------------------
184
-
185
- namespace fft
234
+ template<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2>
235
+ struct FFTMirrorTradeUtils
186
236
{
237
+ using indexing_utils_t = FFTIndexingUtils<ElementsPerInvocationLog2, WorkgroupSizeLog2>;
238
+ using mirror_info_t = typename indexing_utils_t::NablaMirrorGlobalInfo;
239
+ // If trading elements when, for example, unpacking real FFTs, you might do so from within your accessor or from outside.
240
+ // If doing so from within your accessor, particularly if using a preloaded accessor, you might want to do this yourself by
241
+ // using FFTIndexingUtils::getNablaMirrorTradeInfo and trading the elements yourself (an example of how to set this up is given in
242
+ // the FFT Bloom example, in the `fft_mirror_common.hlsl` file).
243
+ // If you're doing this from outside your preloaded accessor then you might want to use this method instead.
244
+ // Note: you can still pass a preloaded accessor as `arrayAccessor` here, it's just that you're going to be doing extra computations for the indices.
245
+ template<typename scalar_t, typename fft_array_accessor_t, typename shared_memory_adaptor_t>
246
+ static complex_t<scalar_t> getNablaMirror (uint32_t globalElementIndex, fft_array_accessor_t arrayAccessor, shared_memory_adaptor_t sharedmemAdaptor)
247
+ {
248
+ const mirror_info_t mirrorInfo = indexing_utils_t::getNablaMirrorGlobalInfo (globalElementIndex);
249
+ complex_t<scalar_t> toTrade = arrayAccessor.get (mirrorInfo.mirrorGlobalIndex);
250
+ vector <scalar_t, 2 > toTradeVector = { toTrade.real (), toTrade.imag () };
251
+ workgroup::Shuffle<shared_memory_adaptor_t, vector <scalar_t, 2 > >::__call (toTradeVector, mirrorInfo.otherThreadID, sharedmemAdaptor);
252
+ toTrade.real (toTradeVector.x);
253
+ toTrade.imag (toTradeVector.y);
254
+ return toTrade;
255
+ }
187
256
188
- template<uint16_t _ElementsPerInvocationLog2, uint16_t _WorkgroupSizeLog2, typename _Scalar NBL_PRIMARY_REQUIRES (_ElementsPerInvocationLog2 > 0 && _WorkgroupSizeLog2 >= 5 )
189
- struct ConstevalParameters
190
- {
191
- using scalar_t = _Scalar;
257
+ NBL_CONSTEXPR_STATIC_INLINE indexing_utils_t IndexingUtils;
258
+ };
192
259
193
- NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = _ElementsPerInvocationLog2;
194
- NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
195
- NBL_CONSTEXPR_STATIC_INLINE uint32_t TotalSize = uint32_t (1 ) << (ElementsPerInvocationLog2 + WorkgroupSizeLog2);
196
260
197
- NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = uint16_t (1 ) << ElementsPerInvocationLog2;
198
- NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t (1 ) << WorkgroupSizeLog2;
199
261
200
- // Required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
201
- NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = (sizeof (complex_t<scalar_t>) / sizeof (uint32_t)) << WorkgroupSizeLog2;
202
- };
203
262
204
263
} //namespace fft
205
264
265
+ // ----------------------------------- End Utils --------------------------------------------------------------
266
+
206
267
template<bool Inverse, typename consteval_params_t, class device_capabilities=void >
207
268
struct FFT;
208
269
@@ -470,7 +531,7 @@ struct FFT<true, fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSi
470
531
}
471
532
}
472
533
473
-
534
+ // ------------------------------- END HLSL ONLY ----------------------------------------------
474
535
#endif
475
536
476
537
#endif
0 commit comments