Skip to content

Commit 415bff4

Browse files
committed
Small math refactor, doing away with workgroup size log 2 requirement
1 parent 51bdd2b commit 415bff4

File tree

1 file changed

+24
-38
lines changed
  • include/nbl/builtin/hlsl/workgroup

1 file changed

+24
-38
lines changed

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 24 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ namespace workgroup
1616
{
1717
namespace fft
1818
{
19+
1920
// ---------------------------------- Utils -----------------------------------------------
2021

2122
template<typename SharedMemoryAccessor, typename Scalar>
@@ -166,33 +167,25 @@ struct FFT<K, false, Scalar, device_capabilities>
166167
template<typename Accessor, typename SharedMemoryAccessor>
167168
static enable_if_t< (mpl::is_pot_v<K> && K > 2), void > __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
168169
{
169-
static const uint32_t virtualThreadCount = K >> 1;
170-
static const uint16_t passes = mpl::log2<K>::value - 1;
171-
uint32_t stride = K >> 1;
172-
//[unroll(passes)]
173-
for (uint16_t pass = 0; pass < passes; pass++)
170+
for (uint32_t stride = (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_; stride > _NBL_HLSL_WORKGROUP_SIZE_; stride >>= 1)
174171
{
175172
//[unroll(K/2)]
176-
for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
173+
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
177174
{
178-
const uint32_t virtualThreadID = (virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex();
179-
180-
const uint32_t lsb = virtualThread & (stride - 1);
181-
const uint32_t loIx = ((virtualThread ^ lsb) << 1) | lsb;
175+
const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
182176
const uint32_t hiIx = loIx | stride;
183177

184-
complex_t<Scalar> lo = accessor.get((loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex());
185-
complex_t<Scalar> hi = accessor.get((hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex());
178+
complex_t<Scalar> lo = accessor.get(loIx);
179+
complex_t<Scalar> hi = accessor.get(hiIx);
186180

187-
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false,Scalar>(virtualThreadID & ((stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) - 1), stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_),lo,hi);
181+
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false,Scalar>(virtualThreadID & (stride - 1), stride),lo,hi);
188182

189-
accessor.set((loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex(), lo);
190-
accessor.set((hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex(), hi);
183+
accessor.set(loIx, lo);
184+
accessor.set(hiIx, hi);
191185
}
192186
accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
193-
stride >>= 1;
194187
}
195-
188+
196189
// do K/2 small workgroup FFTs
197190
DynamicOffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
198191
//[unroll(K/2)]
@@ -225,41 +218,34 @@ struct FFT<K, true, Scalar, device_capabilities>
225218
FFT<2,true, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
226219
}
227220
accessor = offsetAccessor.accessor;
228-
229-
static const uint32_t virtualThreadCount = K >> 1;
230-
static const uint16_t passes = mpl::log2<K>::value - 1;
231-
uint32_t stride = 2;
232-
//[unroll(passes)]
233-
for (uint16_t pass = 0; pass < passes; pass++)
221+
222+
for (uint32_t stride = _NBL_HLSL_WORKGROUP_SIZE_ << 1; stride <= (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
234223
{
235224
//[unroll(K/2)]
236-
for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
225+
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
237226
{
238-
const uint32_t virtualThreadID = (virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex();
239-
240-
const uint32_t lsb = virtualThread & (stride - 1);
241-
const uint32_t loIx = ((virtualThread ^ lsb) << 1) | lsb;
227+
const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
242228
const uint32_t hiIx = loIx | stride;
229+
230+
complex_t<Scalar> lo = accessor.get(loIx);
231+
complex_t<Scalar> hi = accessor.get(hiIx);
243232

244-
complex_t<Scalar> lo = accessor.get((loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex());
245-
complex_t<Scalar> hi = accessor.get((hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex());
246-
247-
hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & ((stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) - 1), stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_),lo,hi);
233+
hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & (stride - 1), stride), lo,hi);
248234

249235
// Divide by special factor at the end
250-
if (passes - 1 == pass)
236+
if ( (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_ == stride)
251237
{
252238
divides_assign< complex_t<Scalar> > divAss;
253-
divAss(lo, virtualThreadCount);
254-
divAss(hi, virtualThreadCount);
239+
divAss(lo, K >> 1);
240+
divAss(hi, K >> 1);
255241
}
256242

257-
accessor.set((loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex(), lo);
258-
accessor.set((hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex(), hi);
243+
accessor.set(loIx, lo);
244+
accessor.set(hiIx, hi);
259245
}
260246
accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
261-
stride <<= 1;
262247
}
248+
263249
}
264250
};
265251

0 commit comments

Comments
 (0)