Skip to content

Commit 550b02d

Browse files
committed
Fixing twiddle math for inverse FFT
1 parent b20a8bc commit 550b02d

File tree

1 file changed

+14
-16
lines changed
  • include/nbl/builtin/hlsl/workgroup

1 file changed

+14
-16
lines changed

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -159,22 +159,20 @@ struct FFT<2,true, Scalar, device_capabilities>
159159

160160
// ---------------------------- Below pending --------------------------------------------------
161161

162-
/*
163-
164162
// Forward FFT
165163
template<uint32_t K, typename Scalar, class device_capabilities>
166-
struct FFT<K,false,device_capabilities>
164+
struct FFT<K, false, Scalar, device_capabilities>
167165
{
168166
template<typename Accessor, typename SharedMemoryAccessor>
169167
static enable_if_t<mpl::is_pot_v<K>, void> __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
170168
{
171169
static const uint32_t virtualThreadCount = K >> 1;
172170
static const uint16_t passes = mpl::log2<K>::value - 1;
173171
uint32_t stride = K >> 1;
174-
[unroll(passes)]
172+
//[unroll(passes)]
175173
for (uint16_t pass = 0; pass < passes; pass++)
176174
{
177-
[unroll(K/2)]
175+
//[unroll(K/2)]
178176
for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
179177
{
180178
const uint32_t virtualThreadID = virtualThread * _NBL_HLSL_WORKGROUP_SIZE_ + SubgroupContiguousIndex();
@@ -186,7 +184,7 @@ struct FFT<K,false,device_capabilities>
186184
complex_t<Scalar> lo = accessor.get(loIx * _NBL_HLSL_WORKGROUP_SIZE_);
187185
complex_t<Scalar> hi = accessor.get(hiIx * _NBL_HLSL_WORKGROUP_SIZE_);
188186

189-
fft::DIF<Scalar>::radix2(fft::twiddle<false,Scalar>(virtualThreadID & (stride - 1), stride),lo,hi);
187+
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false,Scalar>(virtualThreadID & (stride - 1), stride),lo,hi);
190188

191189
accessor.set(loIx, lo);
192190
accessor.set(hiIx, hi);
@@ -196,8 +194,8 @@ struct FFT<K,false,device_capabilities>
196194
}
197195

198196
// do K/2 small workgroup FFTs
199-
OffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
200-
[unroll(K/2)]
197+
DynamicOffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
198+
//[unroll(K/2)]
201199
for (uint32_t k = 0; k < K; k += 2)
202200
{
203201
if (k)
@@ -211,14 +209,14 @@ struct FFT<K,false,device_capabilities>
211209

212210
// Inverse FFT
213211
template<uint32_t K, typename Scalar, class device_capabilities>
214-
struct FFT<K,true,device_capabilities>
212+
struct FFT<K, true, Scalar, device_capabilities>
215213
{
216214
template<typename Accessor, typename SharedMemoryAccessor>
217215
static enable_if_t<mpl::is_pot_v<K>, void> __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
218216
{
219217
// do K/2 small workgroup FFTs
220-
OffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
221-
[unroll(K/2)]
218+
DynamicOffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
219+
//[unroll(K/2)]
222220
for (uint32_t k = 0; k < K; k += 2)
223221
{
224222
if (k)
@@ -231,20 +229,22 @@ struct FFT<K,true,device_capabilities>
231229
static const uint32_t virtualThreadCount = K >> 1;
232230
static const uint16_t passes = mpl::log2<K>::value - 1;
233231
uint32_t stride = K << 1;
234-
[unroll(passes)]
232+
//[unroll(passes)]
235233
for (uint16_t pass = 0; pass < passes; pass++)
236234
{
237-
[unroll(K/2)]
235+
//[unroll(K/2)]
238236
for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
239237
{
238+
const uint32_t virtualThreadID = virtualThread * _NBL_HLSL_WORKGROUP_SIZE_ + SubgroupContiguousIndex();
239+
240240
const uint32_t lsb = virtualThread & (stride - 1);
241241
const uint32_t loIx = ((virtualThread ^ lsb) << 1) | lsb;
242242
const uint32_t hiIx = loIx | stride;
243243

244244
complex_t<Scalar> lo = accessor.get(loIx * _NBL_HLSL_WORKGROUP_SIZE_);
245245
complex_t<Scalar> hi = accessor.get(hiIx * _NBL_HLSL_WORKGROUP_SIZE_);
246246

247-
fft::DIF<Scalar>::radix2(fft::twiddle<true,Scalar>(virtualThreadID & (stride - 1), stride),lo,hi);
247+
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & (stride - 1), stride),lo,hi);
248248

249249
accessor.set(loIx, lo);
250250
accessor.set(hiIx, hi);
@@ -255,8 +255,6 @@ struct FFT<K,true,device_capabilities>
255255
}
256256
};
257257

258-
*/
259-
260258
}
261259
}
262260
}

0 commit comments

Comments
 (0)