@@ -93,15 +93,19 @@ struct exchangeValues<SharedMemoryAdaptor, float64_t>
    }
};

+ // Get the required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
+ template <typename scalar_t, uint32_t WorkgroupSize>
+ NBL_CONSTEXPR uint32_t sharedMemSize = 2 * WorkgroupSize * (sizeof(scalar_t) / sizeof(uint32_t));
+
} //namespace fft

// ----------------------------------- End Utils -----------------------------------------------

- template<uint16_t ElementsPerInvocation, bool Inverse, typename Scalar, class device_capabilities=void>
+ template<uint16_t ElementsPerInvocation, bool Inverse, uint32_t WorkgroupSize, typename Scalar, class device_capabilities=void>
struct FFT;

// For the FFT methods below, we assume:
- // - Accessor is a global memory accessor to an array fitting 2 * _NBL_HLSL_WORKGROUP_SIZE_ elements of type complex_t<Scalar>, used to get inputs / set outputs of the FFT,
+ // - Accessor is a global memory accessor to an array fitting 2 * WorkgroupSize elements of type complex_t<Scalar>, used to get inputs / set outputs of the FFT,
//   that is, one "lo" and one "hi" complex numbers per thread, essentially 4 Scalars per thread. The arrays it accesses with `get` and `set` can optionally be
//   different, if you don't want the FFT to be done in-place.
// The Accessor MUST provide the following methods:
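Aside (not part of the diff): the new `sharedMemSize` variable template makes the shared-memory budget explicit at compile time. A minimal caller-side sketch, assuming the helper is reachable as `fft::sharedMemSize` and that `scalar_t` and `WorkgroupSize` are compile-time constants chosen by the calling shader:

// Backing array for the SharedMemoryAccessor described below.
// For scalar_t = float32_t and WorkgroupSize = 256 this is 2 * 256 * 1 = 512 uint32_t;
// for float64_t it doubles to 1024, since each scalar occupies two uint32_t.
groupshared uint32_t sharedmem[fft::sharedMemSize<scalar_t, WorkgroupSize>];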
@@ -110,15 +114,15 @@ struct FFT;
//     * void memoryBarrier();
//   You might optionally want to provide a `workgroupExecutionAndMemoryBarrier()` method on it to wait on to be sure the whole FFT pass is done

- // - SharedMemoryAccessor accesses a workgroup-shared memory array of size `2 * sizeof(Scalar) * _NBL_HLSL_WORKGROUP_SIZE_`.
+ // - SharedMemoryAccessor accesses a workgroup-shared memory array of size `2 * sizeof(Scalar) * WorkgroupSize`.
// The SharedMemoryAccessor MUST provide the following methods:
//     * void get(uint32_t index, inout uint32_t value);
//     * void set(uint32_t index, in uint32_t value);
//     * void workgroupExecutionAndMemoryBarrier();

// 2 items per invocation forward specialization
- template<typename Scalar, class device_capabilities>
- struct FFT<2, false, Scalar, device_capabilities>
+ template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+ struct FFT<2, false, WorkgroupSize, Scalar, device_capabilities>
{
    template<typename SharedMemoryAdaptor>
    static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
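For reference, a bare-bones type satisfying the SharedMemoryAccessor contract listed above could look like the sketch below. This is illustrative only; it assumes the `groupshared` array from the earlier aside and the `glsl::barrier()` compatibility helper:

// Hypothetical SharedMemoryAccessor over the `sharedmem` array declared earlier
struct SharedMemoryAccessor
{
    void set(uint32_t idx, in uint32_t value)
    {
        sharedmem[idx] = value;
    }

    void get(uint32_t idx, inout uint32_t value)
    {
        value = sharedmem[idx];
    }

    void workgroupExecutionAndMemoryBarrier()
    {
        glsl::barrier(); // execution + shared memory barrier across the workgroup
    }
};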
@@ -136,27 +140,27 @@ struct FFT<2,false, Scalar, device_capabilities>
        // Compute the indices only once
        const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
        const uint32_t loIx = threadID;
-         const uint32_t hiIx = _NBL_HLSL_WORKGROUP_SIZE_ | loIx;
+         const uint32_t hiIx = WorkgroupSize | loIx;

        // Read lo, hi values from global memory
        complex_t<Scalar> lo, hi;
        accessor.get(loIx, lo);
        accessor.get(hiIx, hi);

        // If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
-         if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
+         if (WorkgroupSize > glsl::gl_SubgroupSize())
        {
            // Set up the memory adaptor
-             using adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor,uint32_t,uint32_t,1,_NBL_HLSL_WORKGROUP_SIZE_>;
+             using adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor,uint32_t,uint32_t,1,WorkgroupSize>;
            adaptor_t sharedmemAdaptor;
            sharedmemAdaptor.accessor = sharedmemAccessor;

            // special first iteration
-             hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);
+             hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID, WorkgroupSize), lo, hi);

            // Run bigger steps until Subgroup-sized
            [unroll]
-             for (uint32_t stride = _NBL_HLSL_WORKGROUP_SIZE_ >> 1; stride > glsl::gl_SubgroupSize(); stride >>= 1)
+             for (uint32_t stride = WorkgroupSize >> 1; stride > glsl::gl_SubgroupSize(); stride >>= 1)
            {
                FFT_loop< adaptor_t >(stride, lo, hi, threadID, sharedmemAdaptor);
                sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
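A quick sanity check on the new indexing (numbers are purely illustrative): with WorkgroupSize = 256 the workgroup-level FFT spans 2 * WorkgroupSize = 512 complex elements, and thread 5 reads loIx = 5 and hiIx = 256 | 5 = 261. Since WorkgroupSize is a power of two and loIx < WorkgroupSize, the OR is equivalent to adding WorkgroupSize, so each thread owns two elements exactly half the transform apart, which is what the first DIF radix-2 butterfly with twiddle(threadID, WorkgroupSize) expects.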
@@ -181,8 +185,8 @@ struct FFT<2,false, Scalar, device_capabilities>


// 2 items per invocation inverse specialization
- template<typename Scalar, class device_capabilities>
- struct FFT<2, true, Scalar, device_capabilities>
+ template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+ struct FFT<2, true, WorkgroupSize, Scalar, device_capabilities>
{
    template<typename SharedMemoryAdaptor>
    static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
@@ -200,7 +204,7 @@ struct FFT<2,true, Scalar, device_capabilities>
        // Compute the indices only once
        const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
        const uint32_t loIx = threadID;
-         const uint32_t hiIx = _NBL_HLSL_WORKGROUP_SIZE_ | loIx;
+         const uint32_t hiIx = WorkgroupSize | loIx;

        // Read lo, hi values from global memory
        complex_t<Scalar> lo, hi;
@@ -211,10 +215,10 @@ struct FFT<2,true, Scalar, device_capabilities>
        subgroup::FFT<true, Scalar, device_capabilities>::__call(lo, hi);

        // If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
-         if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
+         if (WorkgroupSize > glsl::gl_SubgroupSize())
        {
            // Set up the memory adaptor
-             using adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor,uint32_t,uint32_t,1,_NBL_HLSL_WORKGROUP_SIZE_>;
+             using adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor,uint32_t,uint32_t,1,WorkgroupSize>;
            adaptor_t sharedmemAdaptor;
            sharedmemAdaptor.accessor = sharedmemAccessor;

@@ -223,18 +227,18 @@ struct FFT<2,true, Scalar, device_capabilities>

            // The bigger steps
            [unroll]
-             for (uint32_t stride = glsl::gl_SubgroupSize() << 1; stride < _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
+             for (uint32_t stride = glsl::gl_SubgroupSize() << 1; stride < WorkgroupSize; stride <<= 1)
            {
                // Order of waiting for shared mem writes is also reversed here, since the shuffle came earlier
                sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
                FFT_loop< adaptor_t >(stride, lo, hi, threadID, sharedmemAdaptor);
            }

            // special last iteration
-             hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);
+             hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID, WorkgroupSize), lo, hi);
            divides_assign< complex_t<Scalar> > divAss;
-             divAss(lo, Scalar(_NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize()));
-             divAss(hi, Scalar(_NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize()));
+             divAss(lo, Scalar(WorkgroupSize / glsl::gl_SubgroupSize()));
+             divAss(hi, Scalar(WorkgroupSize / glsl::gl_SubgroupSize()));

            // Remember to update the accessor's state
            sharedmemAccessor = sharedmemAdaptor.accessor;
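A note on the division above (the subgroup-level scaling factor is an inference, not something visible in this diff): with WorkgroupSize = 256 and a subgroup size of 32, divAss scales lo and hi by 1/8. Assuming subgroup::FFT<true, ...> already applies a 1/(2 * SubgroupSize) = 1/64 factor, the combined scale is 1/512 = 1/(2 * WorkgroupSize), i.e. the standard 1/N normalization for an inverse FFT of length N = 2 * WorkgroupSize.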
@@ -247,17 +251,17 @@ struct FFT<2,true, Scalar, device_capabilities>
};

// Forward FFT
- template<uint32_t K, typename Scalar, class device_capabilities>
- struct FFT<K, false, Scalar, device_capabilities>
+ template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+ struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
{
    template<typename Accessor, typename SharedMemoryAccessor>
    static enable_if_t<(mpl::is_pot_v<K> && K > 2), void> __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
    {
        [unroll]
-         for (uint32_t stride = (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; stride > _NBL_HLSL_WORKGROUP_SIZE_; stride >>= 1)
+         for (uint32_t stride = (K / 2) * WorkgroupSize; stride > WorkgroupSize; stride >>= 1)
        {
            [unroll]
-             for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
+             for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * WorkgroupSize; virtualThreadID += WorkgroupSize)
            {
                const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
                const uint32_t hiIx = loIx | stride;
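Worked example of the virtual-thread indexing (illustrative numbers only): take K = 4 and WorkgroupSize = 256, so the transform spans K * WorkgroupSize = 1024 complex elements, there are (K / 2) * WorkgroupSize = 512 virtual threads, and each real thread handles two of them (e.g. thread 44 takes virtualThreadID 44 and 300). For K = 4 the outer loop runs a single stride of 512, for which the bit manipulation degenerates to loIx = virtualThreadID and hiIx = virtualThreadID + 512; virtualThreadID = 300 therefore pairs elements 300 and 812. The remaining, smaller strides are handled by delegating to the 2-item specialization, as the next hunk shows.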
@@ -282,16 +286,16 @@ struct FFT<K, false, Scalar, device_capabilities>
        {
            if (k)
                sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
-             offsetAccessor.offset = _NBL_HLSL_WORKGROUP_SIZE_*k;
-             FFT<2, false, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
+             offsetAccessor.offset = WorkgroupSize*k;
+             FFT<2, false, WorkgroupSize, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
        }
        accessor = offsetAccessor.accessor;
    }
};

// Inverse FFT
- template<uint32_t K, typename Scalar, class device_capabilities>
- struct FFT<K, true, Scalar, device_capabilities>
+ template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+ struct FFT<K, true, WorkgroupSize, Scalar, device_capabilities>
{
    template<typename Accessor, typename SharedMemoryAccessor>
    static enable_if_t<(mpl::is_pot_v<K> && K > 2), void> __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
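Putting the forward path together: the K-item specialization runs the above-WorkgroupSize strides with virtual threads, then delegates the remaining strides to the 2-item specialization through the offset accessor. A hedged usage sketch follows; the enclosing workgroup namespace, the scalar_t alias and the user-side Accessor / SharedMemoryAccessor types are assumptions of this sketch, not spelled out in the diff:

// 1024-point forward FFT: 256 threads, K = 4 complex elements per invocation.
NBL_CONSTEXPR uint32_t WorkgroupSize = 256;
Accessor accessor;                      // wraps a buffer of K * WorkgroupSize complex_t<scalar_t>
SharedMemoryAccessor sharedmemAccessor; // wraps the groupshared array sized with fft::sharedMemSize
workgroup::FFT<4, false, WorkgroupSize, scalar_t>::template __call(accessor, sharedmemAccessor);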
@@ -304,17 +308,17 @@ struct FFT<K, true, Scalar, device_capabilities>
        {
            if (k)
                sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
-             offsetAccessor.offset = _NBL_HLSL_WORKGROUP_SIZE_*k;
-             FFT<2, true, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
+             offsetAccessor.offset = WorkgroupSize*k;
+             FFT<2, true, WorkgroupSize, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
        }
        accessor = offsetAccessor.accessor;

        [unroll]
-         for (uint32_t stride = 2 * _NBL_HLSL_WORKGROUP_SIZE_; stride < K * _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
+         for (uint32_t stride = 2 * WorkgroupSize; stride < K * WorkgroupSize; stride <<= 1)
        {
            accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
            [unroll]
-             for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
+             for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * WorkgroupSize; virtualThreadID += WorkgroupSize)
            {
                const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
                const uint32_t hiIx = loIx | stride;
@@ -326,7 +330,7 @@ struct FFT<K, true, Scalar, device_capabilities>
                hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & (stride - 1), stride), lo,hi);

                // Divide by special factor at the end
-                 if ((K / 2) * _NBL_HLSL_WORKGROUP_SIZE_ == stride)
+                 if ((K / 2) * WorkgroupSize == stride)
                {
                    divides_assign< complex_t<Scalar> > divAss;
                    divAss(lo, K / 2);
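Closing note on the inverse normalization (same caveat as before about the subgroup-level factor being an assumption): the K / 2 divide at the largest stride composes with the WorkgroupSize / SubgroupSize divide inside FFT<2, true, ...> and the assumed 1/(2 * SubgroupSize) scaling of the subgroup pass. For K = 4, WorkgroupSize = 256 and a subgroup size of 32 that is (1/2) * (1/8) * (1/64) = 1/1024 = 1/(K * WorkgroupSize), again the expected 1/N for an inverse transform of length N.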