8
8
#include "nbl/builtin/hlsl/mpl.hlsl"
9
9
#include "nbl/builtin/hlsl/memory_accessor.hlsl"
10
10
#include "nbl/builtin/hlsl/bit.hlsl"
11
+ #include "nbl/builtin/hlsl/concepts.hlsl"
11
12
12
13
// Caveats
13
14
// - Sin and Cos in HLSL take 32-bit floats. Using this library with 64-bit floats works perfectly fine, but DXC will emit warnings
@@ -90,10 +91,6 @@ namespace impl
90
91
}
91
92
} //namespace impl
92
93
93
- // Get the required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
94
- template <typename scalar_t, uint16_t WorkgroupSize>
95
- NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof (complex_t<scalar_t>) / sizeof (uint32_t)) * WorkgroupSize;
96
-
97
94
// Util to unpack two values from the packed FFT X + iY - get outputs in the same input arguments, storing x to lo and y to hi
98
95
template<typename Scalar>
99
96
void unpack (NBL_REF_ARG (complex_t<Scalar>) lo, NBL_REF_ARG (complex_t<Scalar>) hi)
@@ -103,7 +100,7 @@ void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi
103
100
lo = x;
104
101
}
105
102
106
- template<uint16_t ElementsPerInvocation , uint16_t WorkgroupSize >
103
+ template<uint16_t ElementsPerInvocationLog2 , uint16_t WorkgroupSizeLog2 >
107
104
struct FFTIndexingUtils
108
105
{
109
106
// This function maps the index `idx` in the output array of a Nabla FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = NablaFFT[idx]`
@@ -132,16 +129,36 @@ struct FFTIndexingUtils
132
129
return getNablaIndex (getDFTMirrorIndex (getDFTIndex (idx)));
133
130
}
134
131
135
- NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = mpl::log2<ElementsPerInvocation>::value;
136
- NBL_CONSTEXPR_STATIC_INLINE uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + mpl::log2<WorkgroupSize>::value;
137
- NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = uint32_t (WorkgroupSize) * uint32_t (ElementsPerInvocation);
132
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + WorkgroupSizeLog2;
133
+ NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = uint32_t (1 ) << FFTSizeLog2;
138
134
};
139
135
140
136
} //namespace fft
141
137
142
- // ----------------------------------- End Utils -----------------------------------------------
138
+ // ----------------------------------- End Utils --------------------------------------------------------------
139
+
140
+ namespace fft
141
+ {
142
+
143
+ template<uint16_t _ElementsPerInvocationLog2, uint16_t _WorkgroupSizeLog2, typename _Scalar NBL_PRIMARY_REQUIRES (_ElementsPerInvocationLog2 > 0 && _WorkgroupSizeLog2 >= 5 )
144
+ struct ConstevalParameters
145
+ {
146
+ using scalar_t = _Scalar;
147
+
148
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = _ElementsPerInvocationLog2;
149
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
150
+ NBL_CONSTEXPR_STATIC_INLINE uint32_t TotalSize = uint32_t (1 ) << (ElementsPerInvocationLog2 + WorkgroupSizeLog2);
143
151
144
- template<uint16_t ElementsPerInvocation, bool Inverse, uint16_t WorkgroupSize, typename Scalar, class device_capabilities=void >
152
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = uint16_t (1 ) << ElementsPerInvocationLog2;
153
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t (1 ) << WorkgroupSizeLog2;
154
+
155
+ // Required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
156
+ NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = (sizeof (complex_t<scalar_t>) / sizeof (uint32_t)) << WorkgroupSizeLog2;
157
+ };
158
+
159
+ } //namespace fft
160
+
161
+ template<bool Inverse, typename consteval_params_t, class device_capabilities=void >
145
162
struct FFT;
146
163
147
164
// For the FFT methods below, we assume:
@@ -161,9 +178,11 @@ struct FFT;
161
178
// * void workgroupExecutionAndMemoryBarrier();
162
179
163
180
// 2 items per invocation forward specialization
164
- template<uint16_t WorkgroupSize , typename Scalar, class device_capabilities>
165
- struct FFT<2 , false , WorkgroupSize, Scalar, device_capabilities>
181
+ template<uint16_t WorkgroupSizeLog2 , typename Scalar, class device_capabilities>
182
+ struct FFT<false , fft::ConstevalParameters< 1 , WorkgroupSizeLog2, Scalar> , device_capabilities>
166
183
{
184
+ using consteval_params_t = fft::ConstevalParameters<1 , WorkgroupSizeLog2, Scalar>;
185
+
167
186
template<typename SharedMemoryAdaptor>
168
187
static void FFT_loop (uint32_t stride, NBL_REF_ARG (complex_t<Scalar>) lo, NBL_REF_ARG (complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
169
188
{
@@ -177,6 +196,8 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
177
196
template<typename Accessor, typename SharedMemoryAccessor>
178
197
static void __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
179
198
{
199
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
200
+
180
201
// Compute the indices only once
181
202
const uint32_t threadID = uint32_t (SubgroupContiguousIndex ());
182
203
const uint32_t loIx = threadID;
@@ -222,12 +243,12 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
222
243
}
223
244
};
224
245
225
-
226
-
227
246
// 2 items per invocation inverse specialization
228
- template<uint16_t WorkgroupSize , typename Scalar, class device_capabilities>
229
- struct FFT<2 , true , WorkgroupSize, Scalar, device_capabilities>
247
+ template<uint16_t WorkgroupSizeLog2 , typename Scalar, class device_capabilities>
248
+ struct FFT<true , fft::ConstevalParameters< 1 , WorkgroupSizeLog2, Scalar> , device_capabilities>
230
249
{
250
+ using consteval_params_t = fft::ConstevalParameters<1 , WorkgroupSizeLog2, Scalar>;
251
+
231
252
template<typename SharedMemoryAdaptor>
232
253
static void FFT_loop (uint32_t stride, NBL_REF_ARG (complex_t<Scalar>) lo, NBL_REF_ARG (complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
233
254
{
@@ -241,6 +262,8 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
241
262
template<typename Accessor, typename SharedMemoryAccessor>
242
263
static void __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
243
264
{
265
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
266
+
244
267
// Compute the indices only once
245
268
const uint32_t threadID = uint32_t (SubgroupContiguousIndex ());
246
269
const uint32_t loIx = threadID;
@@ -291,17 +314,23 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
291
314
};
292
315
293
316
// Forward FFT
294
- template<uint32_t K , uint16_t WorkgroupSize , typename Scalar, class device_capabilities>
295
- struct FFT<K, false , WorkgroupSize , Scalar, device_capabilities>
317
+ template<uint16_t ElementsPerInvocationLog2 , uint16_t WorkgroupSizeLog2 , typename Scalar, class device_capabilities>
318
+ struct FFT<false , fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2 , Scalar> , device_capabilities>
296
319
{
320
+ using consteval_params_t = fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2, Scalar>;
321
+ using small_fft_consteval_params_t = fft::ConstevalParameters<1 , WorkgroupSizeLog2, Scalar>;
322
+
297
323
template<typename Accessor, typename SharedMemoryAccessor>
298
- static enable_if_t< (mpl::is_pot_v<K> && K > 2 ), void > __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
324
+ static void __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
299
325
{
326
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
327
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = consteval_params_t::ElementsPerInvocation;
328
+
300
329
[unroll]
301
- for (uint32_t stride = (K / 2 ) * WorkgroupSize; stride > WorkgroupSize; stride >>= 1 )
330
+ for (uint32_t stride = (ElementsPerInvocation / 2 ) * WorkgroupSize; stride > WorkgroupSize; stride >>= 1 )
302
331
{
303
332
[unroll]
304
- for (uint32_t virtualThreadID = SubgroupContiguousIndex (); virtualThreadID < (K / 2 ) * WorkgroupSize; virtualThreadID += WorkgroupSize)
333
+ for (uint32_t virtualThreadID = SubgroupContiguousIndex (); virtualThreadID < (ElementsPerInvocation / 2 ) * WorkgroupSize; virtualThreadID += WorkgroupSize)
305
334
{
306
335
const uint32_t loIx = ((virtualThreadID & (~(stride - 1 ))) << 1 ) | (virtualThreadID & (stride - 1 ));
307
336
const uint32_t hiIx = loIx | stride;
@@ -318,47 +347,53 @@ struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
318
347
accessor.memoryBarrier (); // no execution barrier just making sure writes propagate to accessor
319
348
}
320
349
321
- // do K /2 small workgroup FFTs
350
+ // do ElementsPerInvocation /2 small workgroup FFTs
322
351
accessor_adaptors::Offset<Accessor> offsetAccessor;
323
352
offsetAccessor.accessor = accessor;
324
353
[unroll]
325
- for (uint32_t k = 0 ; k < K ; k += 2 )
354
+ for (uint32_t k = 0 ; k < ElementsPerInvocation ; k += 2 )
326
355
{
327
356
if (k)
328
357
sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
329
358
offsetAccessor.offset = WorkgroupSize*k;
330
- FFT<2 , false , WorkgroupSize, Scalar , device_capabilities>::template __call (offsetAccessor,sharedmemAccessor);
359
+ FFT<false , small_fft_consteval_params_t , device_capabilities>::template __call (offsetAccessor,sharedmemAccessor);
331
360
}
332
361
accessor = offsetAccessor.accessor;
333
362
}
334
363
};
335
364
336
365
// Inverse FFT
337
- template<uint32_t K , uint16_t WorkgroupSize , typename Scalar, class device_capabilities>
338
- struct FFT<K, true , WorkgroupSize , Scalar, device_capabilities>
366
+ template<uint16_t ElementsPerInvocationLog2 , uint16_t WorkgroupSizeLog2 , typename Scalar, class device_capabilities>
367
+ struct FFT<true , fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2 , Scalar> , device_capabilities>
339
368
{
369
+ using consteval_params_t = fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2, Scalar>;
370
+ using small_fft_consteval_params_t = fft::ConstevalParameters<1 , WorkgroupSizeLog2, Scalar>;
371
+
340
372
template<typename Accessor, typename SharedMemoryAccessor>
341
- static enable_if_t< (mpl::is_pot_v<K> && K > 2 ), void > __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
373
+ static void __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
342
374
{
375
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
376
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = consteval_params_t::ElementsPerInvocation;
377
+
343
378
// do ElementsPerInvocation/2 small workgroup FFTs
344
379
accessor_adaptors::Offset<Accessor> offsetAccessor;
345
380
offsetAccessor.accessor = accessor;
346
381
[unroll]
347
- for (uint32_t k = 0 ; k < K ; k += 2 )
382
+ for (uint32_t k = 0 ; k < ElementsPerInvocation ; k += 2 )
348
383
{
349
384
if (k)
350
385
sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
351
386
offsetAccessor.offset = WorkgroupSize*k;
352
- FFT<2 , true , WorkgroupSize, Scalar , device_capabilities>::template __call (offsetAccessor,sharedmemAccessor);
387
+ FFT<true , small_fft_consteval_params_t , device_capabilities>::template __call (offsetAccessor,sharedmemAccessor);
353
388
}
354
389
accessor = offsetAccessor.accessor;
355
390
356
391
[unroll]
357
- for (uint32_t stride = 2 * WorkgroupSize; stride < K * WorkgroupSize; stride <<= 1 )
392
+ for (uint32_t stride = 2 * WorkgroupSize; stride < ElementsPerInvocation * WorkgroupSize; stride <<= 1 )
358
393
{
359
394
accessor.memoryBarrier (); // no execution barrier just making sure writes propagate to accessor
360
395
[unroll]
361
- for (uint32_t virtualThreadID = SubgroupContiguousIndex (); virtualThreadID < (K / 2 ) * WorkgroupSize; virtualThreadID += WorkgroupSize)
396
+ for (uint32_t virtualThreadID = SubgroupContiguousIndex (); virtualThreadID < (ElementsPerInvocation / 2 ) * WorkgroupSize; virtualThreadID += WorkgroupSize)
362
397
{
363
398
const uint32_t loIx = ((virtualThreadID & (~(stride - 1 ))) << 1 ) | (virtualThreadID & (stride - 1 ));
364
399
const uint32_t hiIx = loIx | stride;
@@ -370,11 +405,11 @@ struct FFT<K, true, WorkgroupSize, Scalar, device_capabilities>
370
405
hlsl::fft::DIT<Scalar>::radix2 (hlsl::fft::twiddle<true ,Scalar>(virtualThreadID & (stride - 1 ), stride), lo,hi);
371
406
372
407
// Divide by special factor at the end
373
- if ( (K / 2 ) * WorkgroupSize == stride)
408
+ if ( (ElementsPerInvocation / 2 ) * WorkgroupSize == stride)
374
409
{
375
410
divides_assign< complex_t<Scalar> > divAss;
376
- divAss (lo, K / 2 );
377
- divAss (hi, K / 2 );
411
+ divAss (lo, ElementsPerInvocation / 2 );
412
+ divAss (hi, ElementsPerInvocation / 2 );
378
413
}
379
414
380
415
accessor.set (loIx, lo);
0 commit comments