Skip to content

Commit 83e0cbd

Browse files
committed
- Differentiate concepts for FFT based on ElementsPerInvocationLog2,
this allows the offset accessor passed in the bigger FFTs to pass the concept requirements - Minor change to the offset accessor, I think it follows the original intent
1 parent 2dc70c1 commit 83e0cbd

File tree

3 files changed

+43
-29
lines changed

3 files changed

+43
-29
lines changed

include/nbl/builtin/hlsl/concepts/accessors/fft.hlsl

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,45 @@ namespace fft
2121
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)
2222
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)
2323
#define NBL_CONCEPT_PARAM_0 (accessor, T)
24-
#define NBL_CONCEPT_PARAM_1 (index_t, uint32_t)
25-
#define NBL_CONCEPT_PARAM_2 (value_t, uint32_t)
24+
#define NBL_CONCEPT_PARAM_1 (index, uint32_t)
25+
#define NBL_CONCEPT_PARAM_2 (val, uint32_t)
2626
NBL_CONCEPT_BEGIN(3)
2727
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
28-
#define index_t NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
29-
#define value_t NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
28+
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
29+
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
3030
NBL_CONCEPT_END(
31-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.set(index_t, value_t)), is_same_v, void))
32-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.get(index_t, value_t)), is_same_v, void))
31+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.set(index, val)), is_same_v, void))
32+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.get(index, val)), is_same_v, void))
3333
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void))
3434
);
35-
#undef value_t
36-
#undef index_t
35+
#undef val
36+
#undef index
3737
#undef accessor
3838
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
3939

4040

41+
// The Accessor (for a small FFT) MUST provide the following methods:
42+
// * void get(uint32_t index, inout complex_t<Scalar> value);
43+
// * void set(uint32_t index, in complex_t<Scalar> value);
44+
45+
#define NBL_CONCEPT_NAME SmallFFTAccessor
46+
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)
47+
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(Scalar)
48+
#define NBL_CONCEPT_PARAM_0 (accessor, T)
49+
#define NBL_CONCEPT_PARAM_1 (index, uint32_t)
50+
#define NBL_CONCEPT_PARAM_2 (val, complex_t<Scalar>)
51+
NBL_CONCEPT_BEGIN(3)
52+
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
53+
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
54+
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
55+
NBL_CONCEPT_END(
56+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.set(index, val)), is_same_v, void))
57+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.get(index, val)), is_same_v, void))
58+
);
59+
#undef val
60+
#undef index
61+
#undef accessor
62+
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
4163

4264

4365
// The Accessor MUST provide the following methods:
@@ -49,19 +71,11 @@ NBL_CONCEPT_END(
4971
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)
5072
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(Scalar)
5173
#define NBL_CONCEPT_PARAM_0 (accessor, T)
52-
#define NBL_CONCEPT_PARAM_1 (index_t, uint32_t)
53-
#define NBL_CONCEPT_PARAM_2 (value_t, complex_t<Scalar>)
54-
NBL_CONCEPT_BEGIN(3)
74+
NBL_CONCEPT_BEGIN(1)
5575
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
56-
#define index_t NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
57-
#define value_t NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
5876
NBL_CONCEPT_END(
59-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.set(index_t, value_t)), is_same_v, void))
60-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.get(index_t, value_t)), is_same_v, void))
6177
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.memoryBarrier()), is_same_v, void))
62-
);
63-
#undef value_t
64-
#undef index_t
78+
) && SmallFFTAccessor<T, Scalar>;
6579
#undef accessor
6680
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
6781

include/nbl/builtin/hlsl/memory_accessor.hlsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,15 @@ template<class BaseAccessor, typename IndexType=uint32_t, typename _Offset=void>
205205
struct Offset : impl::OffsetBase<IndexType,_Offset>
206206
{
207207
using base_t = impl::OffsetBase<IndexType,_Offset>;
208+
using index_t = IndexType;
208209

209210
BaseAccessor accessor;
210211

211212
template <typename T>
212-
void set(uint32_t idx, T value) {accessor.set(idx+base_t::offset,value); }
213+
void set(index_t idx, T value) {accessor.set(idx+base_t::offset,value); }
213214

214215
template <typename T>
215-
void get(uint32_t idx, NBL_REF_ARG(T) value) {accessor.get(idx+base_t::offset,value);}
216+
void get(index_t idx, NBL_REF_ARG(T) value) {accessor.get(idx+base_t::offset,value);}
216217

217218
template<typename S=BaseAccessor>
218219
enable_if_t<

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -274,19 +274,18 @@ template<bool Inverse, typename consteval_params_t, class device_capabilities=vo
274274
struct FFT;
275275

276276
// For the FFT methods below, we assume:
277-
// - Accessor is an accessor to an array fitting 2 * WorkgroupSize elements of type complex_t<Scalar>, used to get inputs / set outputs of the FFT,
278-
// that is, one "lo" and one "hi" complex numbers per thread, essentially 4 Scalars per thread. If `ConstevalParameters::ElementsPerInvocationLog2 == 1`,
279-
// the arrays it accesses with `get` and `set` can optionally be different, if you don't want the FFT to be done in-place. Otherwise, you MUST make it in-place
280-
// (this is because if using more than 2 elements per invocation, we use the same array to store intermediate operations).
277+
// - Accessor is an accessor to an array fitting ElementsPerInvocation * WorkgroupSize elements of type complex_t<Scalar>, used to get inputs / set outputs of the FFT.
278+
// If `ConstevalParameters::ElementsPerInvocationLog2 == 1`, the arrays it accesses with `get` and `set` can optionally be different,
279+
// if you don't want the FFT to be done in-place. Otherwise, you MUST make it in-place.
281280
// The Accessor MUST provide the following methods:
282281
// * void get(uint32_t index, inout complex_t<Scalar> value);
283282
// * void set(uint32_t index, in complex_t<Scalar> value);
284283
// * void memoryBarrier();
285284
// For it to work correctly, this memory barrier must use `AcquireRelease` semantics, with the proper flags set for the memory type.
286-
// If using `ConstevalParameters::ElementsPerInvocationLog2 == 1` or otherwise not needing it (such as when using preloaded accessors) we still require the method to exist
287-
// but you can just make it do nothing.
285+
// If using `ConstevalParameters::ElementsPerInvocationLog2 == 1` the Accessor IS ALLOWED TO not provide this last method.
286+
// If not needing it (such as when using preloaded accessors) we still require the method to exist but you can just make it do nothing.
288287

289-
// - SharedMemoryAccessor accesses a workgroup-shared memory array of size `2 * sizeof(Scalar) * WorkgroupSize`.
288+
// - SharedMemoryAccessor accesses a workgroup-shared memory array of size `WorkgroupSize` elements of type complex_t<Scalar>.
290289
// The SharedMemoryAccessor MUST provide the following methods:
291290
// * void get(uint32_t index, inout uint32_t value);
292291
// * void set(uint32_t index, in uint32_t value);
@@ -308,7 +307,7 @@ struct FFT<false, fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>, device
308307
}
309308

310309

311-
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor, Scalar> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
310+
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::SmallFFTAccessor<Accessor, Scalar> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
312311
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
313312
{
314313
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;
@@ -374,7 +373,7 @@ struct FFT<true, fft::ConstevalParameters<1, WorkgroupSizeLog2, Scalar>, device_
374373
}
375374

376375

377-
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::FFTAccessor<Accessor, Scalar> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
376+
template<typename Accessor, typename SharedMemoryAccessor NBL_FUNC_REQUIRES(fft::SmallFFTAccessor<Accessor, Scalar> && fft::FFTSharedMemoryAccessor<SharedMemoryAccessor>)
378377
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
379378
{
380379
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = consteval_params_t::WorkgroupSize;

0 commit comments

Comments
 (0)