Skip to content

Commit 975a7b7

Browse files
committed
Fixed accessor concepts for FFT
1 parent f3ad5e8 commit 975a7b7

File tree

3 files changed

+23
-21
lines changed

3 files changed

+23
-21
lines changed

include/nbl/builtin/hlsl/workgroup/fft/accessor_concepts.hlsl renamed to include/nbl/builtin/hlsl/concepts/accessors/fft.hlsl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_FFT_ACCESSOR_CONCEPTS_INCLUDED_
2-
#define _NBL_BUILTIN_HLSL_WORKGROUP_FFT_ACCESSOR_CONCEPTS_INCLUDED_
1+
#ifndef _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_FFT_INCLUDED_
2+
#define _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_FFT_INCLUDED_
33

4-
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
5-
#include <nbl/builtin/hlsl/concepts.hlsl>
6-
#include <nbl/builtin/hlsl/fft/common.hlsl>
4+
#include "nbl/builtin/hlsl/concepts.hlsl"
5+
#include "nbl/builtin/hlsl/fft/common.hlsl"
76

87
namespace nbl
98
{
@@ -40,24 +39,23 @@ NBL_CONCEPT_END(
4039

4140

4241

43-
// The Accessor MUST provide a typename `Accessor::scalar_t`
42+
4443
// The Accessor MUST provide the following methods:
4544
// * void get(uint32_t index, inout complex_t<Scalar> value);
4645
// * void set(uint32_t index, in complex_t<Scalar> value);
4746
// * void memoryBarrier();
4847

4948
#define NBL_CONCEPT_NAME FFTAccessor
50-
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)
51-
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)
49+
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)
50+
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(Scalar)
5251
#define NBL_CONCEPT_PARAM_0 (accessor, T)
5352
#define NBL_CONCEPT_PARAM_1 (index_t, uint32_t)
54-
#define NBL_CONCEPT_PARAM_2 (value_t, complex_t<typename T::scalar_t>)
55-
NBL_CONCEPT_BEGIN(4)
53+
#define NBL_CONCEPT_PARAM_2 (value_t, complex_t<Scalar>)
54+
NBL_CONCEPT_BEGIN(3)
5655
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
5756
#define index_t NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
5857
#define value_t NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
5958
NBL_CONCEPT_END(
60-
((NBL_CONCEPT_REQ_TYPE)(T::scalar_t))
6159
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.set(index_t, value_t)), is_same_v, void))
6260
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.get(index_t, value_t)), is_same_v, void))
6361
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.memoryBarrier()), is_same_v, void))

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ inline OptimalFFTParameters optimalFFTParameters(const uint32_t maxWorkgroupSize
8383
#include "nbl/builtin/hlsl/mpl.hlsl"
8484
#include "nbl/builtin/hlsl/memory_accessor.hlsl"
8585
#include "nbl/builtin/hlsl/bit.hlsl"
86-
#include "nbl/builtin/hlsl/workgroup/fft/accessor_concepts.hlsl"
86+
#include "nbl/builtin/hlsl/concepts/accessors/fft.hlsl"
8787

8888
// Caveats
8989
// - Sin and Cos in HLSL take 32-bit floats. Using this library with 64-bit floats works perfectly fine, but DXC will emit warnings
@@ -274,14 +274,17 @@ template<bool Inverse, typename consteval_params_t, class device_capabilities=vo
274274
struct FFT;
275275

276276
// For the FFT methods below, we assume:
277-
// - Accessor is a global memory accessor to an array fitting 2 * WorkgroupSize elements of type complex_t<Scalar>, used to get inputs / set outputs of the FFT,
278-
// that is, one "lo" and one "hi" complex numbers per thread, essentially 4 Scalars per thread. The arrays it accesses with `get` and `set` can optionally be
279-
// different, if you don't want the FFT to be done in-place.
280-
// The Accessor MUST provide a typename `Accessor::scalar_t`, and this type MUST be the same as the `Scalar` template parameter of the FFT struct's consteval parameters
277+
// - Accessor is an accessor to an array fitting 2 * WorkgroupSize elements of type complex_t<Scalar>, used to get inputs / set outputs of the FFT,
278+
// that is, one "lo" and one "hi" complex numbers per thread, essentially 4 Scalars per thread. If `ConstevalParameters::ElementsPerInvocationLog2 == 1`,
279+
// the arrays it accesses with `get` and `set` can optionally be different, if you don't want the FFT to be done in-place. Otherwise, you MUST make it in-place
280+
// (this is because if using more than 2 elements per invocation, we use the same array to store intermediate operations).
281281
// The Accessor MUST provide the following methods:
282282
// * void get(uint32_t index, inout complex_t<Scalar> value);
283283
// * void set(uint32_t index, in complex_t<Scalar> value);
284284
// * void memoryBarrier();
285+
// For it to work correctly, this memory barrier must use `AcquireRelease` semantics, with the proper flags set for the memory type.
286+
// If using `ConstevalParameters::ElementsPerInvocationLog2 == 1` or otherwise not needing it (such as when using preloaded accessors) we still require the method to exist
287+
// but you can just make it do nothing.
285288

286289
// - SharedMemoryAccessor accesses a workgroup-shared memory array of size `2 * sizeof(Scalar) * WorkgroupSize`.
287290
// The SharedMemoryAccessor MUST provide the following methods:
@@ -305,7 +308,7 @@ struct FFT<false, fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>, device
305308
}
306309

307310

308-
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
311+
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor, Scalar> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
309312
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
310313
{
311314
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
@@ -371,7 +374,7 @@ struct FFT<true, fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>, device_
371374
}
372375

373376

374-
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
377+
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor, Scalar> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
375378
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
376379
{
377380
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
@@ -432,7 +435,7 @@ struct FFT<false, fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupS
432435
using consteval_params_t = fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2, Scalar>;
433436
using small_fft_consteval_params_t = fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>;
434437

435-
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
438+
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor, Scalar> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
436439
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
437440
{
438441
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
@@ -481,7 +484,7 @@ struct FFT<true, fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSi
481484
using consteval_params_t = fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2, Scalar>;
482485
using small_fft_consteval_params_t = fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>;
483486

484-
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
487+
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor, Scalar> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
485488
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
486489
{
487490
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;

src/nbl/builtin/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,6 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/basic.hlsl")
310310
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/ballot.hlsl")
311311
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/broadcast.hlsl")
312312
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/fft.hlsl")
313-
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/fft/accessor_concepts.hlsl")
314313
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/scratch_size.hlsl")
315314
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/shared_scan.hlsl")
316315
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/shuffle.hlsl")
@@ -331,5 +330,7 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/anisotropi
331330
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/loadable_image.hlsl")
332331
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/mip_mapped.hlsl")
333332
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/storable_image.hlsl")
333+
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/concepts/accessors/fft.hlsl")
334+
334335

335336
ADD_CUSTOM_BUILTIN_RESOURCES(nblBuiltinResourceData NBL_RESOURCES_TO_EMBED "${NBL_ROOT_PATH}/include" "nbl/builtin" "nbl::builtin" "${NBL_ROOT_PATH_BINARY}/include" "${NBL_ROOT_PATH_BINARY}/src" "STATIC" "INTERNAL")

0 commit comments

Comments
 (0)