Skip to content

Commit e618e58

Browse files
committed
Yet more utils, such as bitreversal
1 parent b31705d commit e618e58

File tree

2 files changed

+94
-10
lines changed

2 files changed

+94
-10
lines changed

include/nbl/builtin/hlsl/fft/common.hlsl

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,47 @@
11
#ifndef _NBL_BUILTIN_HLSL_FFT_COMMON_INCLUDED_
22
#define _NBL_BUILTIN_HLSL_FFT_COMMON_INCLUDED_
33

4-
#include "nbl/builtin/hlsl/complex.hlsl"
54
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
5+
6+
#ifndef __HLSL_VERSION
7+
#include <nbl/core/math/intutil.h>
8+
9+
namespace nbl
10+
{
11+
namespace hlsl
12+
{
13+
namespace fft
14+
{
15+
16+
// Pads a set of extents for the FFT: every axis listed in `axes` is rounded with core::roundUpToPoT.
// When `realFFT` is set, the first axis listed in `axes` is additionally halved after the rounding.
// @param dimensions extents to pad, taken by value; entries whose index is not listed in `axes` come back unchanged
// @param axes indices into `dimensions`, visited in the order given
// @param realFFT whether the first listed axis gets halved
// @return the padded copy of `dimensions`
static inline uint32_t3 padDimensions(uint32_t3 dimensions, std::span<uint16_t> axes, bool realFFT = false)
{
    // A plain flag marks "first axis not yet halved". The previous post-incremented uint16_t counter
    // wrapped back to 0 after 65536 entries, which would have halved a second axis.
    bool halvingPending = realFFT;
    for (const uint16_t axis : axes)
    {
        dimensions[axis] = core::roundUpToPoT(dimensions[axis]);
        if (halvingPending)
        {
            dimensions[axis] /= 2;
            halvingPending = false;
        }
    }
    return dimensions;
}
27+
28+
static inline uint64_t getOutputBufferSize(const uint32_t3& inputDimensions, uint32_t numChannels, std::span<uint16_t> axes, bool realFFT = false, bool halfFloats = false)
29+
{
30+
auto paddedDims = padDimensions(inputDimensions, axes);
31+
uint64_t numberOfComplexElements = paddedDims[0] * paddedDims[1] * paddedDims[2] * numChannels;
32+
return 2 * numberOfComplexElements * (halfFloats ? sizeof(float16_t) : sizeof(float32_t));
33+
}
34+
35+
36+
}
37+
}
38+
}
39+
40+
#else
41+
42+
#include "nbl/builtin/hlsl/complex.hlsl"
643
#include "nbl/builtin/hlsl/numbers.hlsl"
44+
#include "nbl/builtin/hlsl/concepts.hlsl"
745

846
namespace nbl
947
{
@@ -53,8 +91,29 @@ using DIT = DIX<true, Scalar>;
5391

5492
template<typename Scalar>
5593
using DIF = DIX<false, Scalar>;
94+
95+
// ------------------------------------------------- Utils ---------------------------------------------------------
96+
//
97+
// Unpacks two values from the packed FFT of X + iY. Works in place on its arguments:
// on return `lo` holds the value for x and `hi` holds the value for y.
template<typename Scalar>
void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
{
    // Both results read the incoming `lo` and `hi`, so compute them first and only then overwrite the arguments
    complex_t<Scalar> hiConjugate = conj(hi);
    complex_t<Scalar> unpackedX = (lo + hiConjugate) * Scalar(0.5);
    complex_t<Scalar> unpackedY = rotateRight<Scalar>(lo - hiConjugate) * Scalar(0.5);
    lo = unpackedX;
    hi = unpackedY;
}
105+
106+
// Bit-reverses `value` as a binary string of length given by `Bits`
// @tparam T integral type of the value (enforced by the requires clause)
// @tparam Bits length of the bit string to reverse, at most the bit width of T
// NOTE(review): assumes glsl::bitfieldReverse can be instantiated for every T used here - confirm for non-32-bit types
// NOTE(review): Bits == 0 passes the constraint but shifts by the full width of T - confirm callers never use it
template<typename T, uint16_t Bits NBL_FUNC_REQUIRES(is_integral_v<T> && Bits <= sizeof(T) * 8)
T bitReverse(T value)
{
    // Reverse across the full width of T, then shift the reversed string down so it occupies the low `Bits` bits.
    // The reversal width must match the width the shift is computed from: reversing as uint32_t while
    // shifting by (sizeof(T) * 8 - Bits) gave wrong results for any T that is not 32 bits wide.
    return glsl::bitfieldReverse<T>(value) >> (sizeof(T) * 8 - Bits);
}
112+
58113
}
114+
}
115+
}
116+
117+
#endif
59118

60119
#endif

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,37 @@
11
#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_FFT_INCLUDED_
22
#define _NBL_BUILTIN_HLSL_WORKGROUP_FFT_INCLUDED_
33

4+
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
5+
#include <nbl/builtin/hlsl/fft/common.hlsl>
6+
7+
#ifndef __HLSL_VERSION
8+
#include <nbl/video/IPhysicalDevice.h>
9+
10+
namespace nbl
11+
{
12+
namespace hlsl
13+
{
14+
namespace workgroup
15+
{
16+
namespace fft
17+
{
18+
19+
// Picks the log2 of the workgroup size and of the elements per invocation for an FFT over `inputArrayLength` elements.
// The workgroup size is the largest power of two that is at most half the padded FFT length and does not
// exceed the device limit; the elements per invocation then cover the rest of the padded length (at least 2).
// @param device logical device whose physical device limits are queried
// @param inputArrayLength number of elements to transform; lengths below 2 are treated as 2
// @return { elementsPerInvocationLog2, workgroupSizeLog2 }
inline std::pair<uint16_t, uint16_t> optimalFFTParameters(const video::ILogicalDevice* device, uint32_t inputArrayLength)
{
    // NOTE(review): the dereference presumably reads the first component of an array-valued limit - confirm
    const uint32_t maxWorkgroupSize = *device->getPhysicalDevice()->getLimits().maxWorkgroupSize;
    // Round the device limit DOWN to a power of two: rounding it up could yield a workgroup size the device rejects
    const uint32_t maxWorkgroupSizeLog2 = static_cast<uint32_t>(hlsl::findMSB(maxWorkgroupSize));
    // Round the length UP to a power of two before splitting it, so truncating shifts cannot drop elements.
    // This is the logic found in core::roundUpToPoT to get the log2. The clamp to 2 keeps the findMSB argument non-zero.
    const uint32_t clampedLength = core::max(inputArrayLength, 2u);
    const uint32_t fftLengthLog2 = 1u + static_cast<uint32_t>(hlsl::findMSB(clampedLength - 1u));
    // At most half the FFT length per workgroup, so that every invocation handles at least two elements
    const uint32_t workgroupSizeLog2 = core::min(fftLengthLog2 - 1u, maxWorkgroupSizeLog2);
    const uint32_t elementsPerInvocationLog2 = fftLengthLog2 - workgroupSizeLog2;
    return { static_cast<uint16_t>(elementsPerInvocationLog2), static_cast<uint16_t>(workgroupSizeLog2) };
}
27+
28+
}
29+
}
30+
}
31+
}
32+
33+
#else
34+
435
#include "nbl/builtin/hlsl/subgroup/fft.hlsl"
536
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
637
#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
@@ -91,15 +122,6 @@ namespace impl
91122
}
92123
} //namespace impl
93124

94-
// Util to unpack two values from the packed FFT X + iY - get outputs in the same input arguments, storing x to lo and y to hi
// (Shown here as removed lines: this commit deletes the definition from workgroup/fft.hlsl and adds the same text to fft/common.hlsl.)
95-
template<typename Scalar>
96-
void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
97-
{
98-
complex_t<Scalar> x = (lo + conj(hi)) * Scalar(0.5);
99-
hi = rotateRight<Scalar>(lo - conj(hi)) * Scalar(0.5);
100-
lo = x;
101-
}
102-
103125
template<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2>
104126
struct FFTIndexingUtils
105127
{
@@ -425,4 +447,7 @@ struct FFT<true, fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSi
425447
}
426448
}
427449

450+
451+
#endif
452+
428453
#endif

0 commit comments

Comments
 (0)