@@ -159,22 +159,20 @@ struct FFT<2,true, Scalar, device_capabilities>
159
159
160
160
// ---------------------------- Below pending --------------------------------------------------
161
161
162
- /*
163
-
164
162
// Forward FFT
165
163
template<uint32_t K, typename Scalar, class device_capabilities>
166
- struct FFT<K,false,device_capabilities>
164
+ struct FFT<K, false , Scalar, device_capabilities>
167
165
{
168
166
template<typename Accessor, typename SharedMemoryAccessor>
169
167
static enable_if_t<mpl::is_pot_v<K>, void > __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
170
168
{
171
169
static const uint32_t virtualThreadCount = K >> 1 ;
172
170
static const uint16_t passes = mpl::log2<K>::value - 1 ;
173
171
uint32_t stride = K >> 1 ;
174
- [unroll(passes)]
172
+ // [unroll(passes)]
175
173
for (uint16_t pass = 0 ; pass < passes; pass ++)
176
174
{
177
- [unroll(K/2)]
175
+ // [unroll(K/2)]
178
176
for (uint32_t virtualThread = 0 ; virtualThread < virtualThreadCount; virtualThread++)
179
177
{
180
178
const uint32_t virtualThreadID = virtualThread * _NBL_HLSL_WORKGROUP_SIZE_ + SubgroupContiguousIndex ();
@@ -186,7 +184,7 @@ struct FFT<K,false,device_capabilities>
186
184
complex_t<Scalar> lo = accessor.get (loIx * _NBL_HLSL_WORKGROUP_SIZE_);
187
185
complex_t<Scalar> hi = accessor.get (hiIx * _NBL_HLSL_WORKGROUP_SIZE_);
188
186
189
- fft::DIF<Scalar>::radix2(fft::twiddle<false,Scalar>(virtualThreadID & (stride - 1), stride),lo,hi);
187
+ hlsl:: fft::DIF<Scalar>::radix2 (hlsl:: fft::twiddle<false ,Scalar>(virtualThreadID & (stride - 1 ), stride),lo,hi);
190
188
191
189
accessor.set (loIx, lo);
192
190
accessor.set (hiIx, hi);
@@ -196,8 +194,8 @@ struct FFT<K,false,device_capabilities>
196
194
}
197
195
198
196
// do K/2 small workgroup FFTs
199
- OffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
200
- [unroll(K/2)]
197
+ DynamicOffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
198
+ // [unroll(K/2)]
201
199
for (uint32_t k = 0 ; k < K; k += 2 )
202
200
{
203
201
if (k)
@@ -211,14 +209,14 @@ struct FFT<K,false,device_capabilities>
211
209
212
210
// Inverse FFT
213
211
template<uint32_t K, typename Scalar, class device_capabilities>
214
- struct FFT<K,true,device_capabilities>
212
+ struct FFT<K, true , Scalar, device_capabilities>
215
213
{
216
214
template<typename Accessor, typename SharedMemoryAccessor>
217
215
static enable_if_t<mpl::is_pot_v<K>, void > __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
218
216
{
219
217
// do K/2 small workgroup FFTs
220
- OffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
221
- [unroll(K/2)]
218
+ DynamicOffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
219
+ // [unroll(K/2)]
222
220
for (uint32_t k = 0 ; k < K; k += 2 )
223
221
{
224
222
if (k)
@@ -231,20 +229,22 @@ struct FFT<K,true,device_capabilities>
231
229
static const uint32_t virtualThreadCount = K >> 1 ;
232
230
static const uint16_t passes = mpl::log2<K>::value - 1 ;
233
231
uint32_t stride = K << 1 ;
234
- [unroll(passes)]
232
+ // [unroll(passes)]
235
233
for (uint16_t pass = 0 ; pass < passes; pass ++)
236
234
{
237
- [unroll(K/2)]
235
+ // [unroll(K/2)]
238
236
for (uint32_t virtualThread = 0 ; virtualThread < virtualThreadCount; virtualThread++)
239
237
{
238
+ const uint32_t virtualThreadID = virtualThread * _NBL_HLSL_WORKGROUP_SIZE_ + SubgroupContiguousIndex ();
239
+
240
240
const uint32_t lsb = virtualThread & (stride - 1 );
241
241
const uint32_t loIx = ((virtualThread ^ lsb) << 1 ) | lsb;
242
242
const uint32_t hiIx = loIx | stride;
243
243
244
244
complex_t<Scalar> lo = accessor.get (loIx * _NBL_HLSL_WORKGROUP_SIZE_);
245
245
complex_t<Scalar> hi = accessor.get (hiIx * _NBL_HLSL_WORKGROUP_SIZE_);
246
246
247
- fft::DIF<Scalar>::radix2(fft::twiddle<true,Scalar>(virtualThreadID & (stride - 1), stride),lo,hi);
247
+ hlsl:: fft::DIF<Scalar>::radix2 (hlsl:: fft::twiddle<true ,Scalar>(virtualThreadID & (stride - 1 ), stride),lo,hi);
248
248
249
249
accessor.set (loIx, lo);
250
250
accessor.set (hiIx, hi);
@@ -255,8 +255,6 @@ struct FFT<K,true,device_capabilities>
255
255
}
256
256
};
257
257
258
- */
259
-
260
258
}
261
259
}
262
260
}
0 commit comments