3
3
4
4
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
5
5
#include <nbl/builtin/hlsl/complex.hlsl>
6
+ #include <nbl/builtin/hlsl/concepts.hlsl>
6
7
7
8
#ifndef __HLSL_VERSION
8
9
#include <nbl/core/math/intutil.h>
@@ -14,22 +15,31 @@ namespace hlsl
14
15
namespace fft
15
16
{
16
17
17
- static inline uint32_t3 padDimensions (uint32_t3 dimensions, std::span<uint16_t> axes, bool realFFT = false )
18
+ // template parameter N controls the number of dimensions of the input
19
+ // template parameter M controls the number of dimensions to pad up to PoT
20
+ // "axes" indicates which dimensions to pad up to PoT
21
+ template <uint16_t N, uint16_t M NBL_FUNC_REQUIRES (M <= N)
22
+ NBL_FORCE_INLINE vector <uint64_t, 3 > padDimensions (NBL_CONST_REF_ARG (vector <uint32_t, N>) dimensions, NBL_CONST_REF_ARG (vector <uint16_t, M>) axes, bool realFFT = false )
18
23
{
24
+ vector <uint32_t, N> newDimensions = dimensions;
19
25
uint16_t axisCount = 0 ;
20
- for (auto i : axes )
26
+ for (uint16_t i = 0u; i < M; i++ )
21
27
{
22
- dimensions [i] = core::roundUpToPoT (dimensions [i]);
28
+ newDimensions [i] = core::roundUpToPoT (newDimensions [i]);
23
29
if (realFFT && !axisCount++)
24
- dimensions [i] /= 2 ;
30
+ newDimensions [i] /= 2 ;
25
31
}
26
- return dimensions ;
32
+ return newDimensions ;
27
33
}
28
34
29
- static inline uint64_t getOutputBufferSize (const uint32_t3& inputDimensions, uint32_t numChannels, std::span<uint16_t> axes, bool realFFT = false , bool halfFloats = false )
35
+ // template parameter N controls the number of dimensions of the input
36
+ // template parameter M controls the number of dimensions we run an FFT along AND store the result
37
+ // "axes" indicates which dimensions we run an FFT along AND store the result
38
+ template <uint16_t N, uint16_t M NBL_FUNC_REQUIRES (M <= N)
39
+ NBL_FORCE_INLINE uint64_t getOutputBufferSize (NBL_CONST_REF_ARG (vector <uint32_t, N>) inputDimensions, uint32_t numChannels, NBL_CONST_REF_ARG (vector <uint16_t, M>) axes, bool realFFT = false , bool halfFloats = false )
30
40
{
31
- auto paddedDims = padDimensions (inputDimensions, axes);
32
- uint64_t numberOfComplexElements = paddedDims[0 ] * paddedDims[1 ] * paddedDims[2 ] * numChannels;
41
+ const vector <uint64_t, 3 > paddedDims = padDimensions<N, M> (inputDimensions, axes);
42
+ const uint64_t numberOfComplexElements = paddedDims[0 ] * paddedDims[1 ] * paddedDims[2 ] * uint64_t ( numChannels) ;
33
43
return numberOfComplexElements * (halfFloats ? sizeof (complex_t<float16_t>) : sizeof (complex_t<float32_t>));
34
44
}
35
45
0 commit comments