Skip to content

Commit 4a16b5d

Browse files
committed
padDimensions and getOutputBufferSize rewritten so they can be shared between cpp and hlsl. Only thing missing is to move intutils.h to a common header as well
1 parent 620e601 commit 4a16b5d

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

include/nbl/builtin/hlsl/fft/common.hlsl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
55
#include <nbl/builtin/hlsl/complex.hlsl>
6+
#include <nbl/builtin/hlsl/concepts.hlsl>
67

78
#ifndef __HLSL_VERSION
89
#include <nbl/core/math/intutil.h>
@@ -14,22 +15,31 @@ namespace hlsl
1415
namespace fft
1516
{
1617

17-
static inline uint32_t3 padDimensions(uint32_t3 dimensions, std::span<uint16_t> axes, bool realFFT = false)
18+
// template parameter N controls the number of dimensions of the input
19+
// template parameter M controls the number of dimensions to pad up to PoT
20+
// "axes" indicates which dimensions to pad up to PoT
21+
template <uint16_t N, uint16_t M NBL_FUNC_REQUIRES(M <= N)
22+
NBL_FORCE_INLINE vector<uint64_t, 3> padDimensions(NBL_CONST_REF_ARG(vector<uint32_t, N>) dimensions, NBL_CONST_REF_ARG(vector<uint16_t, M>) axes, bool realFFT = false)
1823
{
24+
vector<uint32_t, N> newDimensions = dimensions;
1925
uint16_t axisCount = 0;
20-
for (auto i : axes)
26+
for (uint16_t i = 0u; i < M; i++)
2127
{
22-
dimensions[i] = core::roundUpToPoT(dimensions[i]);
28+
newDimensions[i] = core::roundUpToPoT(newDimensions[i]);
2329
if (realFFT && !axisCount++)
24-
dimensions[i] /= 2;
30+
newDimensions[i] /= 2;
2531
}
26-
return dimensions;
32+
return newDimensions;
2733
}
2834

29-
static inline uint64_t getOutputBufferSize(const uint32_t3& inputDimensions, uint32_t numChannels, std::span<uint16_t> axes, bool realFFT = false, bool halfFloats = false)
35+
// template parameter N controls the number of dimensions of the input
36+
// template parameter M controls the number of dimensions we run an FFT along AND store the result
37+
// "axes" indicates which dimensions we run an FFT along AND store the result
38+
template <uint16_t N, uint16_t M NBL_FUNC_REQUIRES(M <= N)
39+
NBL_FORCE_INLINE uint64_t getOutputBufferSize(NBL_CONST_REF_ARG(vector<uint32_t, N>) inputDimensions, uint32_t numChannels, NBL_CONST_REF_ARG(vector<uint16_t, M>) axes, bool realFFT = false, bool halfFloats = false)
3040
{
31-
auto paddedDims = padDimensions(inputDimensions, axes);
32-
uint64_t numberOfComplexElements = paddedDims[0] * paddedDims[1] * paddedDims[2] * numChannels;
41+
const vector<uint64_t, 3> paddedDims = padDimensions<N, M>(inputDimensions, axes);
42+
const uint64_t numberOfComplexElements = paddedDims[0] * paddedDims[1] * paddedDims[2] * uint64_t(numChannels);
3343
return numberOfComplexElements * (halfFloats ? sizeof(complex_t<float16_t>) : sizeof(complex_t<float32_t>));
3444
}
3545

0 commit comments

Comments
 (0)