Skip to content

Commit 9abf1de

Browse files
committed
Require concepts for Accessors for FFT
1 parent 20b4e3a commit 9abf1de

File tree

3 files changed

+82
-5
lines changed

3 files changed

+82
-5
lines changed

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ inline OptimalFFTParameters optimalFFTParameters(const uint32_t maxWorkgroupSize
8383
#include "nbl/builtin/hlsl/mpl.hlsl"
8484
#include "nbl/builtin/hlsl/memory_accessor.hlsl"
8585
#include "nbl/builtin/hlsl/bit.hlsl"
86+
#include "nbl/builtin/hlsl/workgroup/fft/accessor_concepts.hlsl"
8687

8788
// Caveats
8889
// - Sin and Cos in HLSL take 32-bit floats. Using this library with 64-bit floats works perfectly fine, but DXC will emit warnings
@@ -276,11 +277,11 @@ struct FFT;
276277
// - Accessor is a global memory accessor to an array fitting 2 * WorkgroupSize elements of type complex_t<Scalar>, used to get inputs / set outputs of the FFT,
277278
// that is, one "lo" and one "hi" complex number per thread, essentially 4 Scalars per thread. The arrays it accesses with `get` and `set` can optionally be
278279
// different, if you don't want the FFT to be done in-place.
280+
// The Accessor MUST provide a typename `Accessor::scalar_t`, and this type MUST be the same as the `Scalar` template parameter of the FFT struct's `fft::ConstevalParameters`.
279281
// The Accessor MUST provide the following methods:
280282
// * void get(uint32_t index, inout complex_t<Scalar> value);
281283
// * void set(uint32_t index, in complex_t<Scalar> value);
282284
// * void memoryBarrier();
283-
// You might optionally want to provide a `workgroupExecutionAndMemoryBarrier()` method on it to wait on to be sure the whole FFT pass is done
284285

285286
// - SharedMemoryAccessor accesses a workgroup-shared memory array of size `2 * sizeof(Scalar) * WorkgroupSize`.
286287
// The SharedMemoryAccessor MUST provide the following methods:
@@ -304,7 +305,7 @@ struct FFT<false, fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>, device
304305
}
305306

306307

307-
template<typename Accessor, typename SharedMemoryAccessor>
308+
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
308309
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
309310
{
310311
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
@@ -370,7 +371,7 @@ struct FFT<true, fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>, device_
370371
}
371372

372373

373-
template<typename Accessor, typename SharedMemoryAccessor>
374+
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
374375
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
375376
{
376377
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
@@ -431,7 +432,7 @@ struct FFT<false, fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupS
431432
using consteval_params_t = fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2, Scalar>;
432433
using small_fft_consteval_params_t = fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>;
433434

434-
template<typename Accessor, typename SharedMemoryAccessor>
435+
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
435436
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
436437
{
437438
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
@@ -480,7 +481,7 @@ struct FFT<true, fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSi
480481
using consteval_params_t = fft::ConstevalParameters<ElementsPerInvocationLog2, WorkgroupSizeLog2, Scalar>;
481482
using small_fft_consteval_params_t = fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>;
482483

483-
template<typename Accessor, typename SharedMemoryAccessor>
484+
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
484485
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
485486
{
486487
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_FFT_ACCESSOR_CONCEPTS_INCLUDED_
#define _NBL_BUILTIN_HLSL_WORKGROUP_FFT_ACCESSOR_CONCEPTS_INCLUDED_

#include <nbl/builtin/hlsl/cpp_compat.hlsl>
#include <nbl/builtin/hlsl/concepts.hlsl>
#include <nbl/builtin/hlsl/fft/common.hlsl>

// Concepts describing the two accessor types that the workgroup FFT's `__call`
// templates are constrained on (`fft::FFTAccessor` and `fft::FFTSharedMemoryAccessor`).
//
// Both concepts are written with the NBL_CONCEPT_* macro protocol:
//   1. define NBL_CONCEPT_NAME, the template parameter kinds/names and the
//      NBL_CONCEPT_PARAM_<i> (name, type) pairs,
//   2. open the concept with NBL_CONCEPT_BEGIN(<count>),
//   3. define a short alias macro per parameter for use inside the requirements,
//   4. list the requirements inside NBL_CONCEPT_END(...),
//   5. #undef the alias macros and include `concepts/__end.hlsl`.
// NOTE(review): the count passed to NBL_CONCEPT_BEGIN is presumed to be the number of
// NBL_CONCEPT_PARAM_<i> entries defined for the concept - confirm against concepts.hlsl.
//
// Despite their `_t` suffix, `index_t` and `value_t` below are used as argument
// expressions in the `get`/`set` calls, not as type names.

namespace nbl
{
namespace hlsl
{
namespace workgroup
{
namespace fft
{

// ---------------------------------- FFTSharedMemoryAccessor ----------------------------------
// The SharedMemoryAccessor MUST provide the following methods:
//      * void get(uint32_t index, inout uint32_t value);
//      * void set(uint32_t index, in uint32_t value);
//      * void workgroupExecutionAndMemoryBarrier();

#define NBL_CONCEPT_NAME FFTSharedMemoryAccessor
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)
#define NBL_CONCEPT_PARAM_0 (accessor, T)
#define NBL_CONCEPT_PARAM_1 (index_t, uint32_t)
#define NBL_CONCEPT_PARAM_2 (value_t, uint32_t)
// three parameters: PARAM_0 .. PARAM_2
NBL_CONCEPT_BEGIN(3)
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
#define index_t NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
#define value_t NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
NBL_CONCEPT_END(
    ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.set(index_t, value_t)), is_same_v, void))
    ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.get(index_t, value_t)), is_same_v, void))
    ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void))
);
#undef value_t
#undef index_t
#undef accessor
#include <nbl/builtin/hlsl/concepts/__end.hlsl>


// ---------------------------------------- FFTAccessor ----------------------------------------
// The Accessor MUST provide a typename `Accessor::scalar_t`
// The Accessor MUST provide the following methods:
//      * void get(uint32_t index, inout complex_t<Scalar> value);
//      * void set(uint32_t index, in complex_t<Scalar> value);
//      * void memoryBarrier();
// Here `Scalar` is the accessor's own `scalar_t`; this concept only sees `T`, so it cannot
// check that `scalar_t` matches the scalar type the FFT itself is instantiated with.

#define NBL_CONCEPT_NAME FFTAccessor
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)
#define NBL_CONCEPT_PARAM_0 (accessor, T)
#define NBL_CONCEPT_PARAM_1 (index_t, uint32_t)
#define NBL_CONCEPT_PARAM_2 (value_t, complex_t<typename T::scalar_t>)
// Three parameters (PARAM_0 .. PARAM_2), same as FFTSharedMemoryAccessor. The `T::scalar_t`
// type requirement is listed in NBL_CONCEPT_END and is not a parameter, so the count is 3, not 4.
NBL_CONCEPT_BEGIN(3)
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
#define index_t NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
#define value_t NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
NBL_CONCEPT_END(
    ((NBL_CONCEPT_REQ_TYPE)(T::scalar_t))
    ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.set(index_t, value_t)), is_same_v, void))
    ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.get(index_t, value_t)), is_same_v, void))
    ((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.memoryBarrier()), is_same_v, void))
);
#undef value_t
#undef index_t
#undef accessor
#include <nbl/builtin/hlsl/concepts/__end.hlsl>

}
}
}
}

#endif

src/nbl/builtin/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/basic.hlsl")
310310
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/ballot.hlsl")
311311
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/broadcast.hlsl")
312312
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/fft.hlsl")
313+
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/fft/accessor_concepts.hlsl")
313314
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/scratch_size.hlsl")
314315
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/shared_scan.hlsl")
315316
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/shuffle.hlsl")

0 commit comments

Comments
 (0)