Skip to content

Commit dcc537e

Browse files
committed
More changes following Bloom PR review
1 parent e2df09a commit dcc537e

File tree

2 files changed

+70
-35
lines changed

2 files changed

+70
-35
lines changed

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 69 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "nbl/builtin/hlsl/mpl.hlsl"
99
#include "nbl/builtin/hlsl/memory_accessor.hlsl"
1010
#include "nbl/builtin/hlsl/bit.hlsl"
11+
#include "nbl/builtin/hlsl/concepts.hlsl"
1112

1213
// Caveats
1314
// - Sin and Cos in HLSL take 32-bit floats. Using this library with 64-bit floats works perfectly fine, but DXC will emit warnings
@@ -90,10 +91,6 @@ namespace impl
9091
}
9192
} //namespace impl
9293

93-
// Get the required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
94-
template <typename scalar_t, uint16_t WorkgroupSize>
95-
NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof(complex_t<scalar_t>) / sizeof(uint32_t)) * WorkgroupSize;
96-
9794
// Utility to unpack two values from the packed FFT of X + iY. The outputs are written back into the input arguments: x is stored to `lo` and y to `hi`
9895
template<typename Scalar>
9996
void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
@@ -103,7 +100,7 @@ void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi
103100
lo = x;
104101
}
105102

106-
template<uint16_t ElementsPerInvocation, uint16_t WorkgroupSize>
103+
template<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2>
107104
struct FFTIndexingUtils
108105
{
109106
// This function maps the index `idx` in the output array of a Nabla FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = NablaFFT[idx]`
@@ -132,16 +129,36 @@ struct FFTIndexingUtils
132129
return getNablaIndex(getDFTMirrorIndex(getDFTIndex(idx)));
133130
}
134131

135-
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = mpl::log2<ElementsPerInvocation>::value;
136-
NBL_CONSTEXPR_STATIC_INLINE uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + mpl::log2<WorkgroupSize>::value;
137-
NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = uint32_t(WorkgroupSize) * uint32_t(ElementsPerInvocation);
132+
NBL_CONSTEXPR_STATIC_INLINE uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + WorkgroupSizeLog2;
133+
NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = uint32_t(1) << FFTSizeLog2;
138134
};
139135

140136
} //namespace fft
141137

142-
// ----------------------------------- End Utils -----------------------------------------------
138+
// ----------------------------------- End Utils --------------------------------------------------------------
139+
140+
namespace fft
141+
{
142+
143+
// Compile-time parameters of a workgroup FFT, bundled into a single type that is passed to `FFT` as its `consteval_params_t` argument.
// Template parameters:
// - _ElementsPerInvocationLog2: log2 of the number of elements handled per invocation; the constraint below requires it to be > 0 (i.e. at least 2 elements)
// - _WorkgroupSizeLog2: log2 of the workgroup size; the constraint below requires it to be >= 5 (i.e. at least 32 invocations)
// - _Scalar: scalar type, re-exported as `scalar_t` and used through `complex_t<scalar_t>`
// Derived constants: `ElementsPerInvocation` and `WorkgroupSize` are the corresponding powers of two, and `TotalSize == ElementsPerInvocation * WorkgroupSize`.
// NOTE(review): `ElementsPerInvocation` and `WorkgroupSize` are computed as `uint16_t(1) << log2`, so each log2 presumably has to stay below 16 - no upper bound is enforced here, confirm against callers.
template<uint16_t _ElementsPerInvocationLog2, uint16_t _WorkgroupSizeLog2, typename _Scalar NBL_PRIMARY_REQUIRES(_ElementsPerInvocationLog2 > 0 && _WorkgroupSizeLog2 >= 5)
144+
struct ConstevalParameters
145+
{
146+
using scalar_t = _Scalar;
147+
148+
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = _ElementsPerInvocationLog2;
149+
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
150+
NBL_CONSTEXPR_STATIC_INLINE uint32_t TotalSize = uint32_t(1) << (ElementsPerInvocationLog2 + WorkgroupSizeLog2);
143151

144-
template<uint16_t ElementsPerInvocation, bool Inverse, uint16_t WorkgroupSize, typename Scalar, class device_capabilities=void>
152+
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = uint16_t(1) << ElementsPerInvocationLog2;
153+
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(1) << WorkgroupSizeLog2;
154+
155+
// Required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT: one `complex_t<scalar_t>` per invocation of the workgroup
156+
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = (sizeof(complex_t<scalar_t>) / sizeof(uint32_t)) << WorkgroupSizeLog2;
157+
};
158+
159+
} //namespace fft
160+
161+
template<bool Inverse, typename consteval_params_t, class device_capabilities=void>
145162
struct FFT;
146163

147164
// For the FFT methods below, we assume:
@@ -161,9 +178,11 @@ struct FFT;
161178
// * void workgroupExecutionAndMemoryBarrier();
162179

163180
// 2 items per invocation forward specialization
164-
template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
165-
struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
181+
template<uint16_t WorkgroupSizeLog2, typename Scalar, class device_capabilities>
182+
struct FFT<false, fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>, device_capabilities>
166183
{
184+
using consteval_params_t = fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>;
185+
167186
template<typename SharedMemoryAdaptor>
168187
static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
169188
{
@@ -177,6 +196,8 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
177196
template<typename Accessor, typename SharedMemoryAccessor>
178197
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
179198
{
199+
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
200+
180201
// Compute the indices only once
181202
const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
182203
const uint32_t loIx = threadID;
@@ -222,12 +243,12 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
222243
}
223244
};
224245

225-
226-
227246
// 2 items per invocation inverse specialization
228-
template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
229-
struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
247+
template<uint16_t WorkgroupSizeLog2, typename Scalar, class device_capabilities>
248+
struct FFT<true, fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>, device_capabilities>
230249
{
250+
using consteval_params_t = fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>;
251+
231252
template<typename SharedMemoryAdaptor>
232253
static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
233254
{
@@ -241,6 +262,8 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
241262
template<typename Accessor, typename SharedMemoryAccessor>
242263
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
243264
{
265+
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
266+
244267
// Compute the indices only once
245268
const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
246269
const uint32_t loIx = threadID;
@@ -291,17 +314,23 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
291314
};
292315

293316
// Forward FFT
294-
template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
295-
struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
317+
template<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2, typename Scalar, class device_capabilities>
318+
struct FFT<false, fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2, Scalar>, device_capabilities>
296319
{
320+
using consteval_params_t = fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2, Scalar>;
321+
using small_fft_consteval_params_t = fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>;
322+
297323
template<typename Accessor, typename SharedMemoryAccessor>
298-
static enable_if_t< (mpl::is_pot_v<K> && K > 2), void > __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
324+
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
299325
{
326+
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
327+
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = consteval_params_t::ElementsPerInvocation;
328+
300329
[unroll]
301-
for (uint32_t stride = (K / 2) * WorkgroupSize; stride > WorkgroupSize; stride >>= 1)
330+
for (uint32_t stride = (ElementsPerInvocation / 2) * WorkgroupSize; stride > WorkgroupSize; stride >>= 1)
302331
{
303332
[unroll]
304-
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * WorkgroupSize; virtualThreadID += WorkgroupSize)
333+
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (ElementsPerInvocation / 2) * WorkgroupSize; virtualThreadID += WorkgroupSize)
305334
{
306335
const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
307336
const uint32_t hiIx = loIx | stride;
@@ -318,47 +347,53 @@ struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
318347
accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
319348
}
320349

321-
// do K/2 small workgroup FFTs
350+
// do ElementsPerInvocation/2 small workgroup FFTs
322351
accessor_adaptors::Offset<Accessor> offsetAccessor;
323352
offsetAccessor.accessor = accessor;
324353
[unroll]
325-
for (uint32_t k = 0; k < K; k += 2)
354+
for (uint32_t k = 0; k < ElementsPerInvocation; k += 2)
326355
{
327356
if (k)
328357
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
329358
offsetAccessor.offset = WorkgroupSize*k;
330-
FFT<2,false, WorkgroupSize, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
359+
FFT<false, small_fft_consteval_params_t, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
331360
}
332361
accessor = offsetAccessor.accessor;
333362
}
334363
};
335364

336365
// Inverse FFT
337-
template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
338-
struct FFT<K, true, WorkgroupSize, Scalar, device_capabilities>
366+
template<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2, typename Scalar, class device_capabilities>
367+
struct FFT<true, fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2, Scalar>, device_capabilities>
339368
{
369+
using consteval_params_t = fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2, Scalar>;
370+
using small_fft_consteval_params_t = fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>;
371+
340372
template<typename Accessor, typename SharedMemoryAccessor>
341-
static enable_if_t< (mpl::is_pot_v<K> && K > 2), void > __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
373+
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
342374
{
375+
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
376+
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = consteval_params_t::ElementsPerInvocation;
377+
343378
// do ElementsPerInvocation/2 small workgroup FFTs
344379
accessor_adaptors::Offset<Accessor> offsetAccessor;
345380
offsetAccessor.accessor = accessor;
346381
[unroll]
347-
for (uint32_t k = 0; k < K; k += 2)
382+
for (uint32_t k = 0; k < ElementsPerInvocation; k += 2)
348383
{
349384
if (k)
350385
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
351386
offsetAccessor.offset = WorkgroupSize*k;
352-
FFT<2,true, WorkgroupSize, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
387+
FFT<true, small_fft_consteval_params_t, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
353388
}
354389
accessor = offsetAccessor.accessor;
355390

356391
[unroll]
357-
for (uint32_t stride = 2 * WorkgroupSize; stride < K * WorkgroupSize; stride <<= 1)
392+
for (uint32_t stride = 2 * WorkgroupSize; stride < ElementsPerInvocation * WorkgroupSize; stride <<= 1)
358393
{
359394
accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
360395
[unroll]
361-
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * WorkgroupSize; virtualThreadID += WorkgroupSize)
396+
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (ElementsPerInvocation / 2) * WorkgroupSize; virtualThreadID += WorkgroupSize)
362397
{
363398
const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
364399
const uint32_t hiIx = loIx | stride;
@@ -370,11 +405,11 @@ struct FFT<K, true, WorkgroupSize, Scalar, device_capabilities>
370405
hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & (stride - 1), stride), lo,hi);
371406

372407
// Divide by special factor at the end
373-
if ( (K / 2) * WorkgroupSize == stride)
408+
if ( (ElementsPerInvocation / 2) * WorkgroupSize == stride)
374409
{
375410
divides_assign< complex_t<Scalar> > divAss;
376-
divAss(lo, K / 2);
377-
divAss(hi, K / 2);
411+
divAss(lo, ElementsPerInvocation / 2);
412+
divAss(hi, ElementsPerInvocation / 2);
378413
}
379414

380415
accessor.set(loIx, lo);

0 commit comments

Comments
 (0)