@@ -164,7 +164,7 @@ template<uint32_t K, typename Scalar, class device_capabilities>
164
164
struct FFT<K, false , Scalar, device_capabilities>
165
165
{
166
166
template<typename Accessor, typename SharedMemoryAccessor>
167
- static enable_if_t<mpl::is_pot_v<K>, void > __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
167
+ static enable_if_t< ( mpl::is_pot_v<K> && K > 2 ) , void > __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
168
168
{
169
169
static const uint32_t virtualThreadCount = K >> 1 ;
170
170
static const uint16_t passes = mpl::log2<K>::value - 1 ;
@@ -175,16 +175,16 @@ struct FFT<K, false, Scalar, device_capabilities>
175
175
//[unroll(K/2)]
176
176
for (uint32_t virtualThread = 0 ; virtualThread < virtualThreadCount; virtualThread++)
177
177
{
178
- const uint32_t virtualThreadID = virtualThread * _NBL_HLSL_WORKGROUP_SIZE_ + SubgroupContiguousIndex ();
178
+ const uint32_t virtualThreadID = ( virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) + SubgroupContiguousIndex ();
179
179
180
180
const uint32_t lsb = virtualThread & (stride - 1 );
181
181
const uint32_t loIx = ((virtualThread ^ lsb) << 1 ) | lsb;
182
182
const uint32_t hiIx = loIx | stride;
183
183
184
- complex_t<Scalar> lo = accessor.get (loIx * _NBL_HLSL_WORKGROUP_SIZE_ );
185
- complex_t<Scalar> hi = accessor.get (hiIx * _NBL_HLSL_WORKGROUP_SIZE_ );
184
+ complex_t<Scalar> lo = accessor.get (loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_ );
185
+ complex_t<Scalar> hi = accessor.get (hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_ );
186
186
187
- hlsl::fft::DIF<Scalar>::radix2 (hlsl::fft::twiddle<false ,Scalar>(virtualThreadID & (stride - 1 ), stride),lo,hi);
187
+ hlsl::fft::DIF<Scalar>::radix2 (hlsl::fft::twiddle<false ,Scalar>(virtualThreadID & (( stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) - 1 ), stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_ ),lo,hi);
188
188
189
189
accessor.set (loIx, lo);
190
190
accessor.set (hiIx, hi);
@@ -199,7 +199,7 @@ struct FFT<K, false, Scalar, device_capabilities>
199
199
for (uint32_t k = 0 ; k < K; k += 2 )
200
200
{
201
201
if (k)
202
- sharedmemAccessor.executionAndMemoryBarrier ();
202
+ sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
203
203
offsetAccessor.offset = _NBL_HLSL_WORKGROUP_SIZE_*k;
204
204
FFT<2 ,false , Scalar, device_capabilities>::template __call (offsetAccessor,sharedmemAccessor);
205
205
}
@@ -212,20 +212,20 @@ template<uint32_t K, typename Scalar, class device_capabilities>
212
212
struct FFT<K, true , Scalar, device_capabilities>
213
213
{
214
214
template<typename Accessor, typename SharedMemoryAccessor>
215
- static enable_if_t<mpl::is_pot_v<K>, void > __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
215
+ static enable_if_t< ( mpl::is_pot_v<K> && K > 2 ) , void > __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
216
216
{
217
217
// do K/2 small workgroup FFTs
218
218
DynamicOffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
219
219
//[unroll(K/2)]
220
220
for (uint32_t k = 0 ; k < K; k += 2 )
221
221
{
222
222
if (k)
223
- sharedmemAccessor.executionAndMemoryBarrier ();
223
+ sharedmemAccessor.workgroupExecutionAndMemoryBarrier ();
224
224
offsetAccessor.offset = _NBL_HLSL_WORKGROUP_SIZE_*k;
225
225
FFT<2 ,true , Scalar, device_capabilities>::template __call (offsetAccessor,sharedmemAccessor);
226
226
}
227
227
accessor = offsetAccessor.accessor;
228
-
228
+ /*
229
229
static const uint32_t virtualThreadCount = K >> 1;
230
230
static const uint16_t passes = mpl::log2<K>::value - 1;
231
231
uint32_t stride = K << 1;
@@ -235,23 +235,23 @@ struct FFT<K, true, Scalar, device_capabilities>
235
235
//[unroll(K/2)]
236
236
for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
237
237
{
238
- const uint32_t virtualThreadID = virtualThread * _NBL_HLSL_WORKGROUP_SIZE_ + SubgroupContiguousIndex ();
238
+ const uint32_t virtualThreadID = ( virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) + SubgroupContiguousIndex();
239
239
240
240
const uint32_t lsb = virtualThread & (stride - 1);
241
241
const uint32_t loIx = ((virtualThread ^ lsb) << 1) | lsb;
242
242
const uint32_t hiIx = loIx | stride;
243
243
244
- complex_t<Scalar> lo = accessor.get (loIx * _NBL_HLSL_WORKGROUP_SIZE_ );
245
- complex_t<Scalar> hi = accessor.get (hiIx * _NBL_HLSL_WORKGROUP_SIZE_ );
244
+ complex_t<Scalar> lo = accessor.get(loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_ );
245
+ complex_t<Scalar> hi = accessor.get(hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_ );
246
246
247
- hlsl::fft::DIF<Scalar>::radix2 (hlsl::fft::twiddle<true ,Scalar>(virtualThreadID & (stride - 1 ), stride),lo,hi);
247
+ hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & (( stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) - 1), stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_ ),lo,hi);
248
248
249
249
accessor.set(loIx, lo);
250
250
accessor.set(hiIx, hi);
251
251
}
252
252
accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
253
253
stride <<= 1;
254
- }
254
+ }*/
255
255
}
256
256
};
257
257
0 commit comments