Skip to content

Commit f85372c

Browse files
committed
Fixed some more twiddle math. Implementing special first and last iterations (shuffle and butterfly)
1 parent 550b02d commit f85372c

File tree

1 file changed

+14
-14
lines changed
  • include/nbl/builtin/hlsl/workgroup

1 file changed

+14
-14
lines changed

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ template<uint32_t K, typename Scalar, class device_capabilities>
164164
struct FFT<K, false, Scalar, device_capabilities>
165165
{
166166
template<typename Accessor, typename SharedMemoryAccessor>
167-
static enable_if_t<mpl::is_pot_v<K>, void> __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
167+
static enable_if_t< (mpl::is_pot_v<K> && K > 2), void > __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
168168
{
169169
static const uint32_t virtualThreadCount = K >> 1;
170170
static const uint16_t passes = mpl::log2<K>::value - 1;
@@ -175,16 +175,16 @@ struct FFT<K, false, Scalar, device_capabilities>
175175
//[unroll(K/2)]
176176
for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
177177
{
178-
const uint32_t virtualThreadID = virtualThread * _NBL_HLSL_WORKGROUP_SIZE_ + SubgroupContiguousIndex();
178+
const uint32_t virtualThreadID = (virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) + SubgroupContiguousIndex();
179179

180180
const uint32_t lsb = virtualThread & (stride - 1);
181181
const uint32_t loIx = ((virtualThread ^ lsb) << 1) | lsb;
182182
const uint32_t hiIx = loIx | stride;
183183

184-
complex_t<Scalar> lo = accessor.get(loIx * _NBL_HLSL_WORKGROUP_SIZE_);
185-
complex_t<Scalar> hi = accessor.get(hiIx * _NBL_HLSL_WORKGROUP_SIZE_);
184+
complex_t<Scalar> lo = accessor.get(loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_);
185+
complex_t<Scalar> hi = accessor.get(hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_);
186186

187-
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false,Scalar>(virtualThreadID & (stride - 1), stride),lo,hi);
187+
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false,Scalar>(virtualThreadID & ((stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) - 1), stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_),lo,hi);
188188

189189
accessor.set(loIx, lo);
190190
accessor.set(hiIx, hi);
@@ -199,7 +199,7 @@ struct FFT<K, false, Scalar, device_capabilities>
199199
for (uint32_t k = 0; k < K; k += 2)
200200
{
201201
if (k)
202-
sharedmemAccessor.executionAndMemoryBarrier();
202+
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
203203
offsetAccessor.offset = _NBL_HLSL_WORKGROUP_SIZE_*k;
204204
FFT<2,false, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
205205
}
@@ -212,20 +212,20 @@ template<uint32_t K, typename Scalar, class device_capabilities>
212212
struct FFT<K, true, Scalar, device_capabilities>
213213
{
214214
template<typename Accessor, typename SharedMemoryAccessor>
215-
static enable_if_t<mpl::is_pot_v<K>, void> __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
215+
static enable_if_t< (mpl::is_pot_v<K> && K > 2), void > __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
216216
{
217217
// do K/2 small workgroup FFTs
218218
DynamicOffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
219219
//[unroll(K/2)]
220220
for (uint32_t k = 0; k < K; k += 2)
221221
{
222222
if (k)
223-
sharedmemAccessor.executionAndMemoryBarrier();
223+
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
224224
offsetAccessor.offset = _NBL_HLSL_WORKGROUP_SIZE_*k;
225225
FFT<2,true, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
226226
}
227227
accessor = offsetAccessor.accessor;
228-
228+
/*
229229
static const uint32_t virtualThreadCount = K >> 1;
230230
static const uint16_t passes = mpl::log2<K>::value - 1;
231231
uint32_t stride = K << 1;
@@ -235,23 +235,23 @@ struct FFT<K, true, Scalar, device_capabilities>
235235
//[unroll(K/2)]
236236
for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
237237
{
238-
const uint32_t virtualThreadID = virtualThread * _NBL_HLSL_WORKGROUP_SIZE_ + SubgroupContiguousIndex();
238+
const uint32_t virtualThreadID = (virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) + SubgroupContiguousIndex();
239239
240240
const uint32_t lsb = virtualThread & (stride - 1);
241241
const uint32_t loIx = ((virtualThread ^ lsb) << 1) | lsb;
242242
const uint32_t hiIx = loIx | stride;
243243
244-
complex_t<Scalar> lo = accessor.get(loIx * _NBL_HLSL_WORKGROUP_SIZE_);
245-
complex_t<Scalar> hi = accessor.get(hiIx * _NBL_HLSL_WORKGROUP_SIZE_);
244+
complex_t<Scalar> lo = accessor.get(loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_);
245+
complex_t<Scalar> hi = accessor.get(hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_);
246246
247-
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & (stride - 1), stride),lo,hi);
247+
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & ((stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) - 1), stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_),lo,hi);
248248
249249
accessor.set(loIx, lo);
250250
accessor.set(hiIx, hi);
251251
}
252252
accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
253253
stride <<= 1;
254-
}
254+
}*/
255255
}
256256
};
257257

0 commit comments

Comments
 (0)