@@ -16,6 +16,7 @@ namespace workgroup
 {
 namespace fft
 {
+
 // ---------------------------------- Utils -----------------------------------------------
 
 template<typename SharedMemoryAccessor, typename Scalar>
@@ -166,33 +167,25 @@ struct FFT<K, false, Scalar, device_capabilities>
     template<typename Accessor, typename SharedMemoryAccessor>
     static enable_if_t<(mpl::is_pot_v<K> && K > 2), void> __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
     {
-        static const uint32_t virtualThreadCount = K >> 1;
-        static const uint16_t passes = mpl::log2<K>::value - 1;
-        uint32_t stride = K >> 1;
-        //[unroll(passes)]
-        for (uint16_t pass = 0; pass < passes; pass++)
+        for (uint32_t stride = (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_; stride > _NBL_HLSL_WORKGROUP_SIZE_; stride >>= 1)
         {
             //[unroll(K/2)]
-            for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
+            for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
             {
-                const uint32_t virtualThreadID = (virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex();
-
-                const uint32_t lsb = virtualThread & (stride - 1);
-                const uint32_t loIx = ((virtualThread ^ lsb) << 1) | lsb;
+                const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
                 const uint32_t hiIx = loIx | stride;
 
-                complex_t<Scalar> lo = accessor.get((loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex());
-                complex_t<Scalar> hi = accessor.get((hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex());
+                complex_t<Scalar> lo = accessor.get(loIx);
+                complex_t<Scalar> hi = accessor.get(hiIx);
 
-                hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(virtualThreadID & ((stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) - 1), stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_), lo, hi);
+                hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(virtualThreadID & (stride - 1), stride), lo, hi);
 
-                accessor.set((loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex(), lo);
-                accessor.set((hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex(), hi);
+                accessor.set(loIx, lo);
+                accessor.set(hiIx, hi);
             }
             accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
-            stride >>= 1;
         }
-
+
         // do K/2 small workgroup FFTs
        DynamicOffsetAccessor< Accessor, complex_t<Scalar> > offsetAccessor;
         //[unroll(K/2)]
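The DIF hunk folds the workgroup size into stride and virtualThreadID, so the butterfly index comes out of one expression instead of the old lsb/shift two-step followed by a per-access (ix << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex() remap. Below is a standalone C++ sketch (not part of this commit) that brute-force checks both formulations give the same loIx; the names W, tid, strideOld and the loop bounds are stand-ins chosen here for illustration, assuming the workgroup size and K are powers of two as the surrounding code already requires.

// Standalone C++ sketch (not part of this commit): brute-force check that the new
// single-expression butterfly index equals the old two-step computation.
// W stands in for _NBL_HLSL_WORKGROUP_SIZE_, tid for SubgroupContiguousIndex().
#include <cassert>
#include <cstdint>

int main()
{
    for (uint32_t W = 1; W <= 64; W <<= 1)
    for (uint32_t K = 4; K <= 32; K <<= 1)
    for (uint32_t strideOld = K >> 1; strideOld > 1; strideOld >>= 1)
    for (uint32_t virtualThread = 0; virtualThread < (K >> 1); virtualThread++)
    for (uint32_t tid = 0; tid < W; tid++)
    {
        // Old formulation: index in the virtual-thread grid, scaled by W on access.
        const uint32_t lsb   = virtualThread & (strideOld - 1);
        const uint32_t loOld = ((((virtualThread ^ lsb) << 1) | lsb) * W) | tid;

        // New formulation: stride and the virtual thread ID already carry the factor of W.
        const uint32_t stride          = strideOld * W;
        const uint32_t virtualThreadID = virtualThread * W + tid;
        const uint32_t loNew = ((virtualThreadID & ~(stride - 1)) << 1) | (virtualThreadID & (stride - 1));

        assert(loOld == loNew); // hiIx = loIx | stride then agrees as well
    }
    return 0;
}

With stride now carrying the factor of W, the old expression stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_ reduces to the new stride, so the twiddle arguments passed to radix2 are unchanged in value; only the shifts disappear.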
@@ -225,41 +218,34 @@ struct FFT<K, true, Scalar, device_capabilities>
             FFT<2, true, Scalar, device_capabilities>::template __call(offsetAccessor, sharedmemAccessor);
         }
         accessor = offsetAccessor.accessor;
-
-        static const uint32_t virtualThreadCount = K >> 1;
-        static const uint16_t passes = mpl::log2<K>::value - 1;
-        uint32_t stride = 2;
-        //[unroll(passes)]
-        for (uint16_t pass = 0; pass < passes; pass++)
+
+        for (uint32_t stride = _NBL_HLSL_WORKGROUP_SIZE_ << 1; stride <= (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
         {
             //[unroll(K/2)]
-            for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
+            for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
             {
-                const uint32_t virtualThreadID = (virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex();
-
-                const uint32_t lsb = virtualThread & (stride - 1);
-                const uint32_t loIx = ((virtualThread ^ lsb) << 1) | lsb;
+                const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
                 const uint32_t hiIx = loIx | stride;
+
+                complex_t<Scalar> lo = accessor.get(loIx);
+                complex_t<Scalar> hi = accessor.get(hiIx);
 
-                complex_t<Scalar> lo = accessor.get((loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex());
-                complex_t<Scalar> hi = accessor.get((hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex());
-
-                hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(virtualThreadID & ((stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) - 1), stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_), lo, hi);
+                hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(virtualThreadID & (stride - 1), stride), lo, hi);
 
                 // Divide by special factor at the end
-                if (passes - 1 == pass)
+                if ((K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_ == stride)
                 {
                     divides_assign< complex_t<Scalar> > divAss;
-                    divAss(lo, virtualThreadCount);
-                    divAss(hi, virtualThreadCount);
+                    divAss(lo, K >> 1);
+                    divAss(hi, K >> 1);
                 }
 
-                accessor.set((loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex(), lo);
-                accessor.set((hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex(), hi);
+                accessor.set(loIx, lo);
+                accessor.set(hiIx, hi);
             }
             accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
-            stride <<= 1;
         }
+
     }
 };
 
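The DIT hunk mirrors the same index folding for the inverse transform and replaces the passes - 1 == pass test with a comparison against the largest stride. A second standalone C++ sketch (again not part of the commit; W and log2u are illustrative stand-ins) confirming that the rewritten loop still runs log2(K) - 1 passes and that the divide-by-(K/2) normalization fires exactly once, on the final pass:

// Standalone C++ sketch (not part of this commit): pass count and normalization
// condition of the rewritten DIT stride loop. W stands in for _NBL_HLSL_WORKGROUP_SIZE_.
#include <cassert>
#include <cstdint>

static uint32_t log2u(uint32_t x) { uint32_t r = 0; while (x >>= 1) ++r; return r; }

int main()
{
    for (uint32_t W = 1; W <= 64; W <<= 1)
    for (uint32_t K = 4; K <= 32; K <<= 1)
    {
        uint32_t passCount = 0, dividePasses = 0;
        for (uint32_t stride = W << 1; stride <= (K >> 1) * W; stride <<= 1)
        {
            ++passCount;
            if ((K >> 1) * W == stride)    // normalization condition from the new code
                ++dividePasses;
        }
        assert(passCount == log2u(K) - 1); // same count as the old `passes` constant
        assert(dividePasses == 1);         // divide-by-(K/2) fires once, on the largest (last) stride
    }
    return 0;
}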