Skip to content

Commit 250e304

Browse files
committed
Naive unrolling added to both workgroup and subgroup FFT - should get the precise number of loops from device capabilities
1 parent 94a19eb commit 250e304

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

include/nbl/builtin/hlsl/subgroup/fft.hlsl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ struct FFT<false, Scalar, device_capabilities>
5151
fft::DIF<Scalar>::radix2(fft::twiddle<false, Scalar>(glsl::gl_SubgroupInvocationID(), subgroupSize), lo, hi);
5252

5353
// Decimation in Frequency
54+
[unroll]
5455
for (uint32_t stride = subgroupSize >> 1; stride > 0; stride >>= 1)
5556
FFT_loop(stride, lo, hi);
5657
}
@@ -88,6 +89,7 @@ struct FFT<true, Scalar, device_capabilities>
8889
const uint32_t doubleSubgroupSize = subgroupSize << 1; //This is N
8990

9091
// Decimation in Time
92+
[unroll]
9193
for (uint32_t stride = 1; stride < subgroupSize; stride <<= 1)
9294
FFT_loop(stride, lo, hi);
9395

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ struct FFT<2,false, Scalar, device_capabilities>
146146
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);
147147

148148
// Run bigger steps until Subgroup-sized
149+
[unroll]
149150
for (uint32_t stride = _NBL_HLSL_WORKGROUP_SIZE_ >> 1; stride > glsl::gl_SubgroupSize(); stride >>= 1)
150151
{
151152
FFT_loop< MemoryAdaptor<SharedMemoryAccessor> >(stride, lo, hi, threadID, sharedmemAdaptor);
@@ -211,6 +212,7 @@ struct FFT<2,true, Scalar, device_capabilities>
211212
fft::exchangeValues<MemoryAdaptor<SharedMemoryAccessor>, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
212213

213214
// The bigger steps
215+
[unroll]
214216
for (uint32_t stride = glsl::gl_SubgroupSize() << 1; stride < _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
215217
{
216218
// Order of waiting for shared mem writes is also reversed here, since the shuffle came earlier
@@ -241,9 +243,10 @@ struct FFT<K, false, Scalar, device_capabilities>
241243
template<typename Accessor, typename SharedMemoryAccessor>
242244
static enable_if_t< (mpl::is_pot_v<K> && K > 2), void > __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
243245
{
246+
[unroll]
244247
for (uint32_t stride = (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; stride > _NBL_HLSL_WORKGROUP_SIZE_; stride >>= 1)
245248
{
246-
//[unroll(K/2)]
249+
[unroll]
247250
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
248251
{
249252
const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
@@ -263,7 +266,7 @@ struct FFT<K, false, Scalar, device_capabilities>
263266

264267
// do K/2 small workgroup FFTs
265268
DynamicOffsetAccessor <Accessor> offsetAccessor;
266-
//[unroll(K/2)]
269+
[unroll]
267270
for (uint32_t k = 0; k < K; k += 2)
268271
{
269272
if (k)
@@ -284,7 +287,7 @@ struct FFT<K, true, Scalar, device_capabilities>
284287
{
285288
// do K/2 small workgroup FFTs
286289
DynamicOffsetAccessor <Accessor> offsetAccessor;
287-
//[unroll(K/2)]
290+
[unroll]
288291
for (uint32_t k = 0; k < K; k += 2)
289292
{
290293
if (k)
@@ -293,11 +296,12 @@ struct FFT<K, true, Scalar, device_capabilities>
293296
FFT<2,true, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
294297
}
295298
accessor = offsetAccessor.accessor;
296-
299+
300+
[unroll]
297301
for (uint32_t stride = 2 * _NBL_HLSL_WORKGROUP_SIZE_; stride < K * _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
298302
{
299303
accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
300-
//[unroll(K/2)]
304+
[unroll]
301305
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
302306
{
303307
const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));

0 commit comments

Comments
 (0)