@@ -66,7 +66,7 @@ struct FFT<2,false, Scalar, device_capabilities>
66
66
// Compute the indices only once
67
67
const uint32_t threadID = uint32_t (SubgroupContiguousIndex ());
68
68
const uint32_t loIx = threadID;
69
- const uint32_t hiIx = loIx + _NBL_HLSL_WORKGROUP_SIZE_ ;
69
+ const uint32_t hiIx = _NBL_HLSL_WORKGROUP_SIZE_ | loIx ;
70
70
71
71
// Read lo, hi values from global memory
72
72
complex_t<Scalar> lo = accessor.get (loIx);
@@ -119,8 +119,8 @@ struct FFT<2,true, Scalar, device_capabilities>
119
119
{
120
120
// Compute the indices only once
121
121
const uint32_t threadID = uint32_t (SubgroupContiguousIndex ());
122
- const uint32_t loIx = (glsl:: gl_SubgroupID ()<<(glsl:: gl_SubgroupSizeLog2 ()+ 1 ))+glsl:: gl_SubgroupInvocationID () ;
123
- const uint32_t hiIx = loIx+glsl:: gl_SubgroupSize () ;
122
+ const uint32_t loIx = threadID ;
123
+ const uint32_t hiIx = _NBL_HLSL_WORKGROUP_SIZE_ | loIx;
124
124
125
125
// Read lo, hi values from global memory
126
126
complex_t<Scalar> lo = accessor.get (loIx);
@@ -175,19 +175,19 @@ struct FFT<K, false, Scalar, device_capabilities>
175
175
//[unroll(K/2)]
176
176
for (uint32_t virtualThread = 0 ; virtualThread < virtualThreadCount; virtualThread++)
177
177
{
178
- const uint32_t virtualThreadID = (virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) + SubgroupContiguousIndex ();
178
+ const uint32_t virtualThreadID = (virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex ();
179
179
180
180
const uint32_t lsb = virtualThread & (stride - 1 );
181
181
const uint32_t loIx = ((virtualThread ^ lsb) << 1 ) | lsb;
182
182
const uint32_t hiIx = loIx | stride;
183
183
184
- complex_t<Scalar> lo = accessor.get (loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_);
185
- complex_t<Scalar> hi = accessor.get (hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_);
184
+ complex_t<Scalar> lo = accessor.get (( loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex () );
185
+ complex_t<Scalar> hi = accessor.get (( hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex () );
186
186
187
187
hlsl::fft::DIF<Scalar>::radix2 (hlsl::fft::twiddle<false ,Scalar>(virtualThreadID & ((stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) - 1 ), stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_),lo,hi);
188
188
189
- accessor.set (loIx, lo);
190
- accessor.set (hiIx, hi);
189
+ accessor.set (( loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex () , lo);
190
+ accessor.set (( hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex () , hi);
191
191
}
192
192
accessor.memoryBarrier (); // no execution barrier just making sure writes propagate to accessor
193
193
stride >>= 1 ;
@@ -225,33 +225,41 @@ struct FFT<K, true, Scalar, device_capabilities>
225
225
FFT<2 ,true , Scalar, device_capabilities>::template __call (offsetAccessor,sharedmemAccessor);
226
226
}
227
227
accessor = offsetAccessor.accessor;
228
- /*
228
+
229
229
static const uint32_t virtualThreadCount = K >> 1 ;
230
230
static const uint16_t passes = mpl::log2<K>::value - 1 ;
231
- uint32_t stride = K << 1 ;
231
+ uint32_t stride = 2 ;
232
232
//[unroll(passes)]
233
233
for (uint16_t pass = 0 ; pass < passes; pass ++)
234
234
{
235
235
//[unroll(K/2)]
236
236
for (uint32_t virtualThread = 0 ; virtualThread < virtualThreadCount; virtualThread++)
237
237
{
238
- const uint32_t virtualThreadID = (virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) + SubgroupContiguousIndex();
238
+ const uint32_t virtualThreadID = (virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex ();
239
239
240
240
const uint32_t lsb = virtualThread & (stride - 1 );
241
241
const uint32_t loIx = ((virtualThread ^ lsb) << 1 ) | lsb;
242
242
const uint32_t hiIx = loIx | stride;
243
+
244
+ complex_t<Scalar> lo = accessor.get ((loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex ());
245
+ complex_t<Scalar> hi = accessor.get ((hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex ());
246
+
247
+ hlsl::fft::DIT<Scalar>::radix2 (hlsl::fft::twiddle<true ,Scalar>(virtualThreadID & ((stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) - 1 ), stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_),lo,hi);
243
248
244
- complex_t<Scalar> lo = accessor.get(loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_);
245
- complex_t<Scalar> hi = accessor.get(hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_);
246
-
247
- hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & ((stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) - 1), stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_),lo,hi);
248
-
249
- accessor.set(loIx, lo);
250
- accessor.set(hiIx, hi);
249
+ // Divide by special factor at the end
250
+ if (passes - 1 == pass )
251
+ {
252
+ divides_assign< complex_t<Scalar> > divAss;
253
+ divAss (lo, virtualThreadCount);
254
+ divAss (hi, virtualThreadCount);
255
+ }
256
+
257
+ accessor.set ((loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex (), lo);
258
+ accessor.set ((hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex (), hi);
251
259
}
252
260
accessor.memoryBarrier (); // no execution barrier just making sure writes propagate to accessor
253
261
stride <<= 1 ;
254
- }*/
262
+ }
255
263
}
256
264
};
257
265
0 commit comments