Skip to content

Commit b20a8bc

Browse files
committed
More-than-2-per-thread FFT implemented, not yet tested
1 parent 1b7582a commit b20a8bc

File tree

4 files changed

+131
-10
lines changed

4 files changed

+131
-10
lines changed

include/nbl/builtin/hlsl/memory_accessor.hlsl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,36 @@ struct MemoryAdaptor<BaseAccessor, 0>
257257
}
258258
};
259259

260+
// ---------------------------------------------- Offset Accessor ----------------------------------------------------
261+
262+
template<class BaseAccessor, class AccessorType, uint32_t Offset>
263+
struct OffsetAccessor
264+
{
265+
BaseAccessor accessor;
266+
267+
void set(uint32_t idx, NBL_REF_ARG(AccessorType) x) {accessor.set(idx + Offset, x);}
268+
269+
AccessorType get(uint32_t idx) {return accessor.get(idx + Offset);}
270+
271+
// TODO: figure out the `enable_if` syntax for this
272+
void workgroupExecutionAndMemoryBarrier() {accessor.workgroupExecutionAndMemoryBarrier();}
273+
};
274+
275+
// Dynamic offset version
276+
template<class BaseAccessor, class AccessorType>
277+
struct DynamicOffsetAccessor
278+
{
279+
BaseAccessor accessor;
280+
uint32_t offset;
281+
282+
void set(uint32_t idx, NBL_REF_ARG(AccessorType) x) {accessor.set(idx + offset, x);}
283+
284+
AccessorType get(uint32_t idx) {return accessor.get(idx + offset);}
285+
286+
// TODO: figure out the `enable_if` syntax for this
287+
void workgroupExecutionAndMemoryBarrier() {accessor.workgroupExecutionAndMemoryBarrier();}
288+
};
289+
260290
}
261291
}
262292

include/nbl/builtin/hlsl/mpl.hlsl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ struct rotr
6363
static const T value = (S >= 0) ? ((X >> r) | (X << (N - r))) : (X << (-r)) | (X >> (N - (-r)));
6464
};
6565

66+
template<uint64_t N>
67+
struct is_pot : bool_constant< (N > 0 && !(N & (N - 1))) > {};
68+
69+
template<uint64_t N>
70+
NBL_CONSTEXPR_STATIC_INLINE bool is_pot_v = is_pot<N>::value;
6671

6772
}
6873
}

include/nbl/builtin/hlsl/type_traits.hlsl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,9 @@ using is_unbounded_array = std::is_unbounded_array<T>;
526526
template<class T>
527527
using is_scalar = std::is_scalar<T>;
528528

529+
template<class T>
530+
NBL_CONSTEXPR_STATIC_INLINE bool is_scalar_v = is_scalar<T>::value;
531+
529532
template<class T>
530533
struct is_signed : impl::base_type_forwarder<std::is_signed, T> {};
531534

@@ -535,6 +538,9 @@ struct is_unsigned : impl::base_type_forwarder<std::is_unsigned, T> {};
535538
template<class T>
536539
struct is_integral : impl::base_type_forwarder<std::is_integral, T> {};
537540

541+
template<class T>
542+
NBL_CONSTEXPR_STATIC_INLINE bool is_integral_v = is_integral<T>::value;
543+
538544
template<class T>
539545
struct is_floating_point : impl::base_type_forwarder<std::is_floating_point, T> {};
540546

@@ -583,6 +589,9 @@ using extent = std::extent<T, I>;
583589
template<bool B, class T = void>
584590
using enable_if = std::enable_if<B, T>;
585591

592+
template<bool B, class T = void>
593+
using enable_if_t = typename enable_if<B, T>::type;
594+
586595
template<class T>
587596
using alignment_of = std::alignment_of<T>;
588597

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
66
#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
77
#include "nbl/builtin/hlsl/workgroup/shuffle.hlsl"
8+
#include "nbl/builtin/hlsl/mpl.hlsl"
9+
#include "nbl/builtin/hlsl/memory_accessor.hlsl"
810

911
namespace nbl
1012
{
@@ -159,21 +161,96 @@ struct FFT<2,true, Scalar, device_capabilities>
159161

160162
/*
161163
162-
// then define 4,8,16 in terms of calling the FFT<2> and doing the special radix steps before/after
163-
template<uint16_t K, bool Inverse, class device_capabilities>
164-
struct FFT
164+
// Forward FFT
165+
template<uint32_t K, typename Scalar, class device_capabilities>
166+
struct FFT<K,false,device_capabilities>
165167
{
166-
template<typename Accessor, typename ShaderMemoryAccessor>
167-
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(ShaderMemoryAccessor) sharedmemAccessor)
168+
template<typename Accessor, typename SharedMemoryAccessor>
169+
static enable_if_t<mpl::is_pot_v<K>, void> __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
170+
{
171+
static const uint32_t virtualThreadCount = K >> 1;
172+
static const uint16_t passes = mpl::log2<K>::value - 1;
173+
uint32_t stride = K >> 1;
174+
[unroll(passes)]
175+
for (uint16_t pass = 0; pass < passes; pass++)
176+
{
177+
[unroll(K/2)]
178+
for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
179+
{
180+
const uint32_t virtualThreadID = virtualThread * _NBL_HLSL_WORKGROUP_SIZE_ + SubgroupContiguousIndex();
181+
182+
const uint32_t lsb = virtualThread & (stride - 1);
183+
const uint32_t loIx = ((virtualThread ^ lsb) << 1) | lsb;
184+
const uint32_t hiIx = loIx | stride;
185+
186+
complex_t<Scalar> lo = accessor.get(loIx * _NBL_HLSL_WORKGROUP_SIZE_);
187+
complex_t<Scalar> hi = accessor.get(hiIx * _NBL_HLSL_WORKGROUP_SIZE_);
188+
189+
fft::DIF<Scalar>::radix2(fft::twiddle<false,Scalar>(virtualThreadID & (stride - 1), stride),lo,hi);
190+
191+
accessor.set(loIx, lo);
192+
accessor.set(hiIx, hi);
193+
}
194+
accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
195+
stride >>= 1;
196+
}
197+
198+
// do K/2 small workgroup FFTs
199+
OffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
200+
[unroll(K/2)]
201+
for (uint32_t k = 0; k < K; k += 2)
202+
{
203+
if (k)
204+
sharedmemAccessor.executionAndMemoryBarrier();
205+
offsetAccessor.offset = _NBL_HLSL_WORKGROUP_SIZE_*k;
206+
FFT<2,false, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
207+
}
208+
accessor = offsetAccessor.accessor;
209+
}
210+
};
211+
212+
// Inverse FFT
213+
template<uint32_t K, typename Scalar, class device_capabilities>
214+
struct FFT<K,true,device_capabilities>
215+
{
216+
template<typename Accessor, typename SharedMemoryAccessor>
217+
static enable_if_t<mpl::is_pot_v<K>, void> __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
168218
{
169-
if (!Inverse)
219+
// do K/2 small workgroup FFTs
220+
OffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
221+
[unroll(K/2)]
222+
for (uint32_t k = 0; k < K; k += 2)
170223
{
171-
... special steps ...
224+
if (k)
225+
sharedmemAccessor.executionAndMemoryBarrier();
226+
offsetAccessor.offset = _NBL_HLSL_WORKGROUP_SIZE_*k;
227+
FFT<2,true, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
172228
}
173-
FFT<2,Inverse,device_capabilities>::template __call<Accessor,SharedMemoryAccessor>(access,sharedMemAccessor);
174-
if (Inverse)
229+
accessor = offsetAccessor.accessor;
230+
231+
static const uint32_t virtualThreadCount = K >> 1;
232+
static const uint16_t passes = mpl::log2<K>::value - 1;
233+
uint32_t stride = K << 1;
234+
[unroll(passes)]
235+
for (uint16_t pass = 0; pass < passes; pass++)
175236
{
176-
... special steps ...
237+
[unroll(K/2)]
238+
for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
239+
{
240+
const uint32_t lsb = virtualThread & (stride - 1);
241+
const uint32_t loIx = ((virtualThread ^ lsb) << 1) | lsb;
242+
const uint32_t hiIx = loIx | stride;
243+
244+
complex_t<Scalar> lo = accessor.get(loIx * _NBL_HLSL_WORKGROUP_SIZE_);
245+
complex_t<Scalar> hi = accessor.get(hiIx * _NBL_HLSL_WORKGROUP_SIZE_);
246+
247+
fft::DIF<Scalar>::radix2(fft::twiddle<true,Scalar>(virtualThreadID & (stride - 1), stride),lo,hi);
248+
249+
accessor.set(loIx, lo);
250+
accessor.set(hiIx, hi);
251+
}
252+
accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
253+
stride <<= 1;
177254
}
178255
}
179256
};

0 commit comments

Comments
 (0)