1
1
#ifndef _NBL_BUILTIN_HLSL_FFT_COMMON_INCLUDED_
2
2
#define _NBL_BUILTIN_HLSL_FFT_COMMON_INCLUDED_
3
3
4
- #include "nbl/builtin/hlsl/complex.hlsl"
5
4
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
5
+
6
+ #ifndef __HLSL_VERSION
7
+ #include <nbl/core/math/intutil.h>
8
+
9
+ namespace nbl
10
+ {
11
+ namespace hlsl
12
+ {
13
+ namespace fft
14
+ {
15
+
16
+ static inline uint32_t3 padDimensions (uint32_t3 dimensions, std::span<uint16_t> axes, bool realFFT = false )
17
+ {
18
+ uint16_t axisCount = 0 ;
19
+ for (auto i : axes)
20
+ {
21
+ dimensions[i] = core::roundUpToPoT (dimensions[i]);
22
+ if (realFFT && !axisCount++)
23
+ dimensions[i] /= 2 ;
24
+ }
25
+ return dimensions;
26
+ }
27
+
28
+ static inline uint64_t getOutputBufferSize (const uint32_t3& inputDimensions, uint32_t numChannels, std::span<uint16_t> axes, bool realFFT = false , bool halfFloats = false )
29
+ {
30
+ auto paddedDims = padDimensions (inputDimensions, axes);
31
+ uint64_t numberOfComplexElements = paddedDims[0 ] * paddedDims[1 ] * paddedDims[2 ] * numChannels;
32
+ return 2 * numberOfComplexElements * (halfFloats ? sizeof (float16_t) : sizeof (float32_t));
33
+ }
34
+
35
+
36
+ }
37
+ }
38
+ }
39
+
40
+ #else
41
+
42
+ #include "nbl/builtin/hlsl/complex.hlsl"
6
43
#include "nbl/builtin/hlsl/numbers.hlsl"
44
+ #include "nbl/builtin/hlsl/concepts.hlsl"
7
45
8
46
namespace nbl
9
47
{
@@ -53,8 +91,29 @@ using DIT = DIX<true, Scalar>;
53
91
54
92
template<typename Scalar>
55
93
using DIF = DIX<false , Scalar>;
94
+
95
+ // ------------------------------------------------- Utils ---------------------------------------------------------
96
+ //
97
// Splits a packed pair in place. On entry `lo` and `hi` hold the two packed values;
// on exit:
//   lo = (lo + conj(hi)) * 0.5
//   hi = rotateRight(lo - conj(hi)) * 0.5
// where both right-hand sides use the values `lo` and `hi` had on entry.
// Per the packing convention X + iY, the first result is x and the second is y.
template<typename Scalar>
void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
{
	const complex_t<Scalar> hiConj = conj(hi);
	// Both outputs are computed from the original inputs before either argument is overwritten
	const complex_t<Scalar> first = (lo + hiConj) * Scalar(0.5);
	const complex_t<Scalar> second = rotateRight<Scalar>(lo - hiConj) * Scalar(0.5);
	lo = first;
	hi = second;
}
105
+
106
// Bit-reverses `value`, treating it as a binary string of length `Bits`.
//
// T     - integral type of the value; `Bits` may not exceed its width in bits
// Bits  - length of the binary string to reverse
//
// The reversal is done over the full width of T, which leaves the reversed string in
// the top `Bits` bits; the right shift by (width of T - Bits) then moves it down to the
// low bits. Reversing over a fixed 32-bit width while shifting by a T-dependent amount
// only lines up when T is exactly 32 bits wide, so the reversal is templated on T.
// NOTE(review): assumes glsl::bitfieldReverse can be instantiated for every T used here - confirm.
// NOTE(review): Bits == 0 shifts by the full width of T - confirm that case is never requested.
template<typename T, uint16_t Bits NBL_FUNC_REQUIRES(is_integral_v<T> && Bits <= sizeof(T) * 8)
T bitReverse(T value)
{
	return glsl::bitfieldReverse<T>(value) >> (sizeof(T) * 8 - Bits);
}
112
+
58
113
}
114
+ }
115
+ }
116
+
117
+ #endif
59
118
60
119
#endif
0 commit comments