Skip to content

Commit e61ab7a

Browse files
committed
Forgot what changed
1 parent 4a16b5d commit e61ab7a

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

include/nbl/builtin/hlsl/fft/common.hlsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace fft
1919
// template parameter M controls the number of dimensions to pad up to PoT
2020
// "axes" indicates which dimensions to pad up to PoT
2121
template <uint16_t N, uint16_t M NBL_FUNC_REQUIRES(M <= N)
22-
NBL_FORCE_INLINE vector<uint64_t, 3> padDimensions(NBL_CONST_REF_ARG(vector<uint32_t, N>) dimensions, NBL_CONST_REF_ARG(vector<uint16_t, M>) axes, bool realFFT = false)
22+
inline vector<uint64_t, 3> padDimensions(NBL_CONST_REF_ARG(vector<uint32_t, N>) dimensions, NBL_CONST_REF_ARG(vector<uint16_t, M>) axes, bool realFFT = false)
2323
{
2424
vector<uint32_t, N> newDimensions = dimensions;
2525
uint16_t axisCount = 0;
@@ -36,7 +36,7 @@ NBL_FORCE_INLINE vector<uint64_t, 3> padDimensions(NBL_CONST_REF_ARG(vector<uint
3636
// template parameter M controls the number of dimensions we run an FFT along AND store the result
3737
// "axes" indicates which dimensions we run an FFT along AND store the result
3838
template <uint16_t N, uint16_t M NBL_FUNC_REQUIRES(M <= N)
39-
NBL_FORCE_INLINE uint64_t getOutputBufferSize(NBL_CONST_REF_ARG(vector<uint32_t, N>) inputDimensions, uint32_t numChannels, NBL_CONST_REF_ARG(vector<uint16_t, M>) axes, bool realFFT = false, bool halfFloats = false)
39+
inline uint64_t getOutputBufferSize(NBL_CONST_REF_ARG(vector<uint32_t, N>) inputDimensions, uint32_t numChannels, NBL_CONST_REF_ARG(vector<uint16_t, M>) axes, bool realFFT = false, bool halfFloats = false)
4040
{
4141
const vector<uint64_t, 3> paddedDims = padDimensions<N, M>(inputDimensions, axes);
4242
const uint64_t numberOfComplexElements = paddedDims[0] * paddedDims[1] * paddedDims[2] * uint64_t(numChannels);

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_FFT_INCLUDED_
2-
#define _NBL_BUILTIN_HLSL_WORKGROUP_FFT_INCLUDED_
3-
41
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
52
#include <nbl/builtin/hlsl/concepts.hlsl>
63
#include <nbl/builtin/hlsl/fft/common.hlsl>
74

5+
#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_FFT_INCLUDED_
6+
#define _NBL_BUILTIN_HLSL_WORKGROUP_FFT_INCLUDED_
7+
88
// ------------------------------- COMMON -----------------------------------------
99

1010
namespace nbl
@@ -36,11 +36,11 @@ struct ConstevalParameters
3636
}
3737
}
3838
}
39-
// ------------------------------- END COMMON -----------------------------------------
39+
// ------------------------------- END COMMON ---------------------------------------------
40+
41+
// -------------------------------- CPP ONLY ----------------------------------------------
4042

41-
// ------------------------------- CPP ONLY -------------------------------------------
4243
#ifndef __HLSL_VERSION
43-
#include <nbl/video/IPhysicalDevice.h>
4444

4545
namespace nbl
4646
{
@@ -51,24 +51,30 @@ namespace workgroup
5151
namespace fft
5252
{
5353

54-
inline std::pair<uint16_t, uint16_t> optimalFFTParameters(const video::ILogicalDevice* device, uint32_t inputArrayLength)
54+
struct OptimalFFTParameters
55+
{
56+
uint16_t elementsPerInvocationLog2;
57+
uint16_t workgroupSizeLog2;
58+
};
59+
60+
inline OptimalFFTParameters optimalFFTParameters(const uint32_t maxWorkgroupSize, uint32_t inputArrayLength)
5561
{
56-
uint32_t maxWorkgroupSize = *device->getPhysicalDevice()->getLimits().maxWorkgroupSize;
5762
// This is the logic found in core::roundUpToPoT to get the log2
58-
uint16_t workgroupSizeLog2 = 1u + hlsl::findMSB(core::min(inputArrayLength / 2, maxWorkgroupSize) - 1u);
59-
uint16_t elementPerInvocationLog2 = 1u + hlsl::findMSB(core::max((inputArrayLength >> workgroupSizeLog2) - 1u, 1u));
60-
return { elementPerInvocationLog2, workgroupSizeLog2 };
63+
const uint16_t workgroupSizeLog2 = 1u + findMSB(min(inputArrayLength / 2, maxWorkgroupSize) - 1u);
64+
const uint16_t elementsPerInvocationLog2 = 1u + findMSB(max((inputArrayLength >> workgroupSizeLog2) - 1u, 1u));
65+
const OptimalFFTParameters retVal = { elementsPerInvocationLog2, workgroupSizeLog2 };
66+
return retVal;
6167
}
6268

6369
}
6470
}
6571
}
6672
}
67-
6873
// ------------------------------- END CPP ONLY -------------------------------------------
6974

7075
// ------------------------------- HLSL ONLY ----------------------------------------------
71-
#else
76+
77+
#else
7278

7379
#include "nbl/builtin/hlsl/subgroup/fft.hlsl"
7480
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"

0 commit comments

Comments
 (0)