Skip to content

Commit cd07b9b

Browse files
committed
Share complex types between cpp and hlsl, add mirror trade functionality for realt FFTs
1 parent 58d8929 commit cd07b9b

File tree

3 files changed

+112
-26
lines changed

3 files changed

+112
-26
lines changed

include/nbl/builtin/hlsl/complex.hlsl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,30 @@
55
#ifndef _NBL_BUILTIN_HLSL_COMPLEX_INCLUDED_
66
#define _NBL_BUILTIN_HLSL_COMPLEX_INCLUDED_
77

8+
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
9+
10+
// -------------------------------------- CPP VERSION ------------------------------------
11+
#ifndef __HLSL_VERSION
12+
13+
#include <complex>
14+
15+
namespace nbl
16+
{
17+
namespace hlsl
18+
{
19+
20+
template<class T>
21+
using complex_t = std::complex<T>;
22+
23+
}
24+
}
25+
26+
// -------------------------------------- END CPP VERSION ------------------------------------
27+
28+
// -------------------------------------- HLSL VERSION ---------------------------------------
29+
#else
30+
831
#include "nbl/builtin/hlsl/functional.hlsl"
9-
#include "nbl/builtin/hlsl/cpp_compat/promote.hlsl"
1032

1133
namespace nbl
1234
{
@@ -409,4 +431,7 @@ NBL_REGISTER_OBJ_TYPE(complex_t<float64_t2>,::nbl::hlsl::alignment_of_v<float64_
409431
NBL_REGISTER_OBJ_TYPE(complex_t<float64_t3>,::nbl::hlsl::alignment_of_v<float64_t3>)
410432
NBL_REGISTER_OBJ_TYPE(complex_t<float64_t4>,::nbl::hlsl::alignment_of_v<float64_t4>)
411433

434+
// -------------------------------------- END HLSL VERSION ---------------------------------------
412435
#endif
436+
437+
#endif

include/nbl/builtin/hlsl/fft/common.hlsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#ifndef _NBL_BUILTIN_HLSL_FFT_COMMON_INCLUDED_
22
#define _NBL_BUILTIN_HLSL_FFT_COMMON_INCLUDED_
33

4-
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
4+
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
5+
#include <nbl/builtin/hlsl/complex.hlsl>
56

67
#ifndef __HLSL_VERSION
78
#include <nbl/core/math/intutil.h>
@@ -29,7 +30,7 @@ static inline uint64_t getOutputBufferSize(const uint32_t3& inputDimensions, uin
2930
{
3031
auto paddedDims = padDimensions(inputDimensions, axes);
3132
uint64_t numberOfComplexElements = paddedDims[0] * paddedDims[1] * paddedDims[2] * numChannels;
32-
return 2 * numberOfComplexElements * (halfFloats ? sizeof(float16_t) : sizeof(float32_t));
33+
return numberOfComplexElements * (halfFloats ? sizeof(complex_t<float16_t>) : sizeof(complex_t<float32_t>));
3334
}
3435

3536

@@ -39,7 +40,6 @@ static inline uint64_t getOutputBufferSize(const uint32_t3& inputDimensions, uin
3940

4041
#else
4142

42-
#include "nbl/builtin/hlsl/complex.hlsl"
4343
#include "nbl/builtin/hlsl/numbers.hlsl"
4444
#include "nbl/builtin/hlsl/concepts.hlsl"
4545

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 83 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,43 @@
22
#define _NBL_BUILTIN_HLSL_WORKGROUP_FFT_INCLUDED_
33

44
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
5+
#include <nbl/builtin/hlsl/concepts.hlsl>
56
#include <nbl/builtin/hlsl/fft/common.hlsl>
67

8+
// ------------------------------- COMMON -----------------------------------------
9+
10+
namespace nbl
11+
{
12+
namespace hlsl
13+
{
14+
namespace workgroup
15+
{
16+
namespace fft
17+
{
18+
19+
template<uint16_t _ElementsPerInvocationLog2, uint16_t _WorkgroupSizeLog2, typename _Scalar NBL_PRIMARY_REQUIRES(_ElementsPerInvocationLog2 > 0 && _WorkgroupSizeLog2 >= 5)
20+
struct ConstevalParameters
21+
{
22+
using scalar_t = _Scalar;
23+
24+
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = _ElementsPerInvocationLog2;
25+
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
26+
NBL_CONSTEXPR_STATIC_INLINE uint32_t TotalSize = uint32_t(1) << (ElementsPerInvocationLog2 + WorkgroupSizeLog2);
27+
28+
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = uint16_t(1) << ElementsPerInvocationLog2;
29+
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(1) << WorkgroupSizeLog2;
30+
31+
// Required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
32+
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = (sizeof(complex_t<scalar_t>) / sizeof(uint32_t)) << WorkgroupSizeLog2;
33+
};
34+
35+
}
36+
}
37+
}
38+
}
39+
// ------------------------------- END COMMON -----------------------------------------
40+
41+
// ------------------------------- CPP ONLY -------------------------------------------
742
#ifndef __HLSL_VERSION
843
#include <nbl/video/IPhysicalDevice.h>
944

@@ -30,6 +65,9 @@ inline std::pair<uint16_t, uint16_t> optimalFFTParameters(const video::ILogicalD
3065
}
3166
}
3267

68+
// ------------------------------- END CPP ONLY -------------------------------------------
69+
70+
// ------------------------------- HLSL ONLY ----------------------------------------------
3371
#else
3472

3573
#include "nbl/builtin/hlsl/subgroup/fft.hlsl"
@@ -39,7 +77,6 @@ inline std::pair<uint16_t, uint16_t> optimalFFTParameters(const video::ILogicalD
3977
#include "nbl/builtin/hlsl/mpl.hlsl"
4078
#include "nbl/builtin/hlsl/memory_accessor.hlsl"
4179
#include "nbl/builtin/hlsl/bit.hlsl"
42-
#include "nbl/builtin/hlsl/concepts.hlsl"
4380

4481
// Caveats
4582
// - Sin and Cos in HLSL take 32-bit floats. Using this library with 64-bit floats works perfectly fine, but DXC will emit warnings
@@ -157,19 +194,35 @@ struct FFTIndexingUtils
157194
// but also the thread holding said mirror value will at the same time be trying to unpack `NFFT[someOtherIndex]` and need the mirror value of that.
158195
// As long as this unpacking is happening concurrently and in order (meaning the local element index - the higher bits - of `globalElementIndex` and `someOtherIndex` is the
159196
// same) then this function returns both the SubgroupContiguousIndex of the other thread AND the local element index of *the mirror* of `someOtherIndex`
160-
struct NablaMirrorTradeInfo
197+
struct NablaMirrorLocalInfo
161198
{
162199
uint32_t otherThreadID;
163200
uint32_t mirrorLocalIndex;
164201
};
165202

166-
static NablaMirrorTradeInfo getNablaMirrorTradeInfo(uint32_t localElementIndex)
203+
static NablaMirrorLocalInfo getNablaMirrorLocalInfo(uint32_t localElementIndex)
167204
{
168205
const uint32_t globalElementIndex = localElementIndex * WorkgroupSize | workgroup::SubgroupContiguousIndex();
169206
const uint32_t otherElementIndex = FFTIndexingUtils::getNablaMirrorIndex(globalElementIndex);
170207
const uint32_t mirrorLocalIndex = otherElementIndex / WorkgroupSize;
171208
const uint32_t otherThreadID = otherElementIndex & (WorkgroupSize - 1);
172-
NablaMirrorTradeInfo info = { otherThreadID, mirrorLocalIndex };
209+
const NablaMirrorLocalInfo info = { otherThreadID, mirrorLocalIndex };
210+
return info;
211+
}
212+
213+
// Like the above, but return global indices instead.
214+
struct NablaMirrorGlobalInfo
215+
{
216+
uint32_t otherThreadID;
217+
uint32_t mirrorGlobalIndex;
218+
};
219+
220+
static NablaMirrorGlobalInfo getNablaMirrorGlobalInfo(uint32_t globalElementIndex)
221+
{
222+
const uint32_t otherElementIndex = FFTIndexingUtils::getNablaMirrorIndex(globalElementIndex);
223+
const uint32_t mirrorGlobalIndex = glsl::bitfieldInsert<uint32_t>(otherElementIndex, workgroup::SubgroupContiguousIndex(), 0, uint32_t(WorkgroupSizeLog2));
224+
const uint32_t otherThreadID = otherElementIndex & (WorkgroupSize - 1);
225+
const NablaMirrorGlobalInfo info = { otherThreadID, mirrorGlobalIndex };
173226
return info;
174227
}
175228

@@ -178,31 +231,39 @@ struct FFTIndexingUtils
178231
NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkgroupSize = uint32_t(1) << WorkgroupSizeLog2;
179232
};
180233

181-
} //namespace fft
182-
183-
// ----------------------------------- End Utils --------------------------------------------------------------
184-
185-
namespace fft
234+
template<uint16_t ElementsPerInvocationLog2, uint16_t WorkgroupSizeLog2>
235+
struct FFTMirrorTradeUtils
186236
{
237+
using indexing_utils_t = FFTIndexingUtils<ElementsPerInvocationLog2, WorkgroupSizeLog2>;
238+
using mirror_info_t = typename indexing_utils_t::NablaMirrorGlobalInfo;
239+
// If trading elements when, for example, unpacking real FFTs, you might do so from within your accessor or from outside.
240+
// If doing so from within your accessor, particularly if using a preloaded accessor, you might want to do this yourself by
241+
// using FFTIndexingUtils::getNablaMirrorTradeInfo and trading the elements yourself (an example of how to set this up is given in
242+
// the FFT Bloom example, in the `fft_mirror_common.hlsl` file).
243+
// If you're doing this from outside your preloaded accessor then you might want to use this method instead.
244+
// Note: you can still pass a preloaded accessor as `arrayAccessor` here, it's just that you're going to be doing extra computations for the indices.
245+
template<typename scalar_t, typename fft_array_accessor_t, typename shared_memory_adaptor_t>
246+
static complex_t<scalar_t> getNablaMirror(uint32_t globalElementIndex, fft_array_accessor_t arrayAccessor, shared_memory_adaptor_t sharedmemAdaptor)
247+
{
248+
const mirror_info_t mirrorInfo = indexing_utils_t::getNablaMirrorGlobalInfo(globalElementIndex);
249+
complex_t<scalar_t> toTrade = arrayAccessor.get(mirrorInfo.mirrorGlobalIndex);
250+
vector<scalar_t, 2> toTradeVector = { toTrade.real(), toTrade.imag() };
251+
workgroup::Shuffle<shared_memory_adaptor_t, vector<scalar_t, 2> >::__call(toTradeVector, mirrorInfo.otherThreadID, sharedmemAdaptor);
252+
toTrade.real(toTradeVector.x);
253+
toTrade.imag(toTradeVector.y);
254+
return toTrade;
255+
}
187256

188-
template<uint16_t _ElementsPerInvocationLog2, uint16_t _WorkgroupSizeLog2, typename _Scalar NBL_PRIMARY_REQUIRES(_ElementsPerInvocationLog2 > 0 && _WorkgroupSizeLog2 >= 5)
189-
struct ConstevalParameters
190-
{
191-
using scalar_t = _Scalar;
257+
NBL_CONSTEXPR_STATIC_INLINE indexing_utils_t IndexingUtils;
258+
};
192259

193-
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = _ElementsPerInvocationLog2;
194-
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSizeLog2 = _WorkgroupSizeLog2;
195-
NBL_CONSTEXPR_STATIC_INLINE uint32_t TotalSize = uint32_t(1) << (ElementsPerInvocationLog2 + WorkgroupSizeLog2);
196260

197-
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocation = uint16_t(1) << ElementsPerInvocationLog2;
198-
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(1) << WorkgroupSizeLog2;
199261

200-
// Required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
201-
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedMemoryDWORDs = (sizeof(complex_t<scalar_t>) / sizeof(uint32_t)) << WorkgroupSizeLog2;
202-
};
203262

204263
} //namespace fft
205264

265+
// ----------------------------------- End Utils --------------------------------------------------------------
266+
206267
template<bool Inverse, typename consteval_params_t, class device_capabilities=void>
207268
struct FFT;
208269

@@ -470,7 +531,7 @@ struct FFT<true, fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSi
470531
}
471532
}
472533

473-
534+
// ------------------------------- END HLSL ONLY ----------------------------------------------
474535
#endif
475536

476537
#endif

0 commit comments

Comments
 (0)