Skip to content

Commit af8abff

Browse files
committed
Added some comments explaining what accessors should do
1 parent 983a83b commit af8abff

File tree

1 file changed

+16
-7
lines changed
  • include/nbl/builtin/hlsl/workgroup

1 file changed

+16
-7
lines changed

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ struct exchangeValues<SharedMemoryAdaptor, float16_t>
3232
static void __call(NBL_REF_ARG(complex_t<float16_t>) lo, NBL_REF_ARG(complex_t<float16_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
3333
{
3434
const bool topHalf = bool(threadID & stride);
35-
// Ternary won't take structs so we use this aux variable
35+
// Pack two halves into a single uint32_t
3636
uint32_t toExchange = bit_cast<uint32_t, float16_t2 >(topHalf ? float16_t2 (lo.real(), lo.imag()) : float16_t2 (hi.real(), hi.imag()));
3737
shuffleXor<SharedMemoryAdaptor, uint32_t>::__call(toExchange, stride, sharedmemAdaptor);
3838
float16_t2 exchanged = bit_cast<float16_t2, uint32_t>(toExchange);
@@ -55,7 +55,7 @@ struct exchangeValues<SharedMemoryAdaptor, float32_t>
5555
static void __call(NBL_REF_ARG(complex_t<float32_t>) lo, NBL_REF_ARG(complex_t<float32_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
5656
{
5757
const bool topHalf = bool(threadID & stride);
58-
// Ternary won't take structs so we use this aux variable
58+
// Cast to uint32_t to use the shared memory for shuffling
5959
uint32_t2 toExchange = bit_cast<uint32_t2, float32_t2 >(topHalf ? float32_t2(lo.real(), lo.imag()) : float32_t2(hi.real(), hi.imag()));
6060
shuffleXor<SharedMemoryAdaptor, uint32_t2 >::__call(toExchange, stride, sharedmemAdaptor);
6161
float32_t2 exchanged = bit_cast<float32_t2, uint32_t2 >(toExchange);
@@ -78,7 +78,7 @@ struct exchangeValues<SharedMemoryAdaptor, float64_t>
7878
static void __call(NBL_REF_ARG(complex_t<float64_t>) lo, NBL_REF_ARG(complex_t<float64_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
7979
{
8080
const bool topHalf = bool(threadID & stride);
81-
// Ternary won't take structs so we use this aux variable
81+
// Unpack two doubles into four uint32_t
8282
uint32_t4 toExchange = bit_cast<uint32_t4, float64_t2 > (topHalf ? float64_t2(lo.real(), lo.imag()) : float64_t2(hi.real(), hi.imag()));
8383
shuffleXor<SharedMemoryAdaptor, uint32_t4 >::__call(toExchange, stride, sharedmemAdaptor);
8484
float64_t2 exchanged = bit_cast<float64_t2, uint32_t4 >(toExchange);
@@ -104,10 +104,19 @@ struct FFT;
104104

105105
// For the FFT methods below, we assume:
106106
// - Accessor is a global memory accessor to an array fitting 2 * _NBL_HLSL_WORKGROUP_SIZE_ elements of type complex_t<Scalar>, used to get inputs / set outputs of the FFT,
107-
// that is, one "lo" and one "hi" complex numbers per thread, essentially 4 Scalars per thread.
108-
// There are no assumptions on the data layout: we just require the accessor to provide get and set methods for complex_t<Scalar>.
109-
// - SharedMemoryAccessor accesses a shared memory array that can fit _NBL_HLSL_WORKGROUP_SIZE_ elements of type complex_t<Scalar>, with get and set
110-
// methods for complex_t<Scalar>. It benefits from coalesced accesses
107+
// that is, one "lo" and one "hi" complex numbers per thread, essentially 4 Scalars per thread. The arrays it accesses with `get` and `set` can optionally be
108+
// different, if you don't want the FFT to be done in-place.
109+
// The Accessor MUST provide the following methods:
110+
// * void get(uint32_t index, inout complex_t<Scalar> value);
111+
// * void set(uint32_t index, in complex_t<Scalar> value);
112+
// * void memoryBarrier();
113+
// You might optionally want to provide a `workgroupExecutionAndMemoryBarrier()` method on it to wait on to be sure the whole FFT pass is done
114+
115+
// - SharedMemoryAccessor accesses a workgroup-shared memory array of size `2 * sizeof(Scalar) * _NBL_HLSL_WORKGROUP_SIZE_`.
116+
// The SharedMemoryAccessor MUST provide the following methods:
117+
// * void get(uint32_t index, inout uint32_t value);
118+
// * void set(uint32_t index, in uint32_t value);
119+
// * void workgroupExecutionAndMemoryBarrier();
111120

112121
// 2 items per invocation forward specialization
113122
template<typename Scalar, class device_capabilities>

0 commit comments

Comments
 (0)