Skip to content

Commit b3e889e

Browse files
committed
Changes the workgroup FFT to control the workgroup size via a template parameter instead of a preprocessor define
1 parent 5e7c522 commit b3e889e

File tree

1 file changed

+36
-32
lines changed
  • include/nbl/builtin/hlsl/workgroup

1 file changed

+36
-32
lines changed

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,19 @@ struct exchangeValues<SharedMemoryAdaptor, float64_t>
9393
}
9494
};
9595

96+
// Get the required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
97+
template <typename scalar_t, uint32_t WorkgroupSize>
98+
NBL_CONSTEXPR uint32_t sharedMemSize = 2 * WorkgroupSize * (sizeof(scalar_t) / sizeof(uint32_t));
99+
96100
} //namespace fft
97101

98102
// ----------------------------------- End Utils -----------------------------------------------
99103

100-
template<uint16_t ElementsPerInvocation, bool Inverse, typename Scalar, class device_capabilities=void>
104+
template<uint16_t ElementsPerInvocation, bool Inverse, uint32_t WorkgroupSize, typename Scalar, class device_capabilities=void>
101105
struct FFT;
102106

103107
// For the FFT methods below, we assume:
104-
// - Accessor is a global memory accessor to an array fitting 2 * _NBL_HLSL_WORKGROUP_SIZE_ elements of type complex_t<Scalar>, used to get inputs / set outputs of the FFT,
108+
// - Accessor is a global memory accessor to an array fitting 2 * WorkgroupSize elements of type complex_t<Scalar>, used to get inputs / set outputs of the FFT,
105109
// that is, one "lo" and one "hi" complex numbers per thread, essentially 4 Scalars per thread. The arrays it accesses with `get` and `set` can optionally be
106110
// different, if you don't want the FFT to be done in-place.
107111
// The Accessor MUST provide the following methods:
@@ -110,15 +114,15 @@ struct FFT;
110114
// * void memoryBarrier();
111115
// You might optionally want to provide a `workgroupExecutionAndMemoryBarrier()` method on it to wait on to be sure the whole FFT pass is done
112116

113-
// - SharedMemoryAccessor accesses a workgroup-shared memory array of size `2 * sizeof(Scalar) * _NBL_HLSL_WORKGROUP_SIZE_`.
117+
// - SharedMemoryAccessor accesses a workgroup-shared memory array of size `2 * sizeof(Scalar) * WorkgroupSize`.
114118
// The SharedMemoryAccessor MUST provide the following methods:
115119
// * void get(uint32_t index, inout uint32_t value);
116120
// * void set(uint32_t index, in uint32_t value);
117121
// * void workgroupExecutionAndMemoryBarrier();
118122

119123
// 2 items per invocation forward specialization
120-
template<typename Scalar, class device_capabilities>
121-
struct FFT<2,false, Scalar, device_capabilities>
124+
template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
125+
struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
122126
{
123127
template<typename SharedMemoryAdaptor>
124128
static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
@@ -136,27 +140,27 @@ struct FFT<2,false, Scalar, device_capabilities>
136140
// Compute the indices only once
137141
const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
138142
const uint32_t loIx = threadID;
139-
const uint32_t hiIx = _NBL_HLSL_WORKGROUP_SIZE_ | loIx;
143+
const uint32_t hiIx = WorkgroupSize | loIx;
140144

141145
// Read lo, hi values from global memory
142146
complex_t<Scalar> lo, hi;
143147
accessor.get(loIx, lo);
144148
accessor.get(hiIx, hi);
145149

146150
// If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
147-
if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
151+
if (WorkgroupSize > glsl::gl_SubgroupSize())
148152
{
149153
// Set up the memory adaptor
150-
using adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor,uint32_t,uint32_t,1,_NBL_HLSL_WORKGROUP_SIZE_>;
154+
using adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor,uint32_t,uint32_t,1,WorkgroupSize>;
151155
adaptor_t sharedmemAdaptor;
152156
sharedmemAdaptor.accessor = sharedmemAccessor;
153157

154158
// special first iteration
155-
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);
159+
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID, WorkgroupSize), lo, hi);
156160

157161
// Run bigger steps until Subgroup-sized
158162
[unroll]
159-
for (uint32_t stride = _NBL_HLSL_WORKGROUP_SIZE_ >> 1; stride > glsl::gl_SubgroupSize(); stride >>= 1)
163+
for (uint32_t stride = WorkgroupSize >> 1; stride > glsl::gl_SubgroupSize(); stride >>= 1)
160164
{
161165
FFT_loop< adaptor_t >(stride, lo, hi, threadID, sharedmemAdaptor);
162166
sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
@@ -181,8 +185,8 @@ struct FFT<2,false, Scalar, device_capabilities>
181185

182186

183187
// 2 items per invocation inverse specialization
184-
template<typename Scalar, class device_capabilities>
185-
struct FFT<2,true, Scalar, device_capabilities>
188+
template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
189+
struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
186190
{
187191
template<typename SharedMemoryAdaptor>
188192
static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
@@ -200,7 +204,7 @@ struct FFT<2,true, Scalar, device_capabilities>
200204
// Compute the indices only once
201205
const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
202206
const uint32_t loIx = threadID;
203-
const uint32_t hiIx = _NBL_HLSL_WORKGROUP_SIZE_ | loIx;
207+
const uint32_t hiIx = WorkgroupSize | loIx;
204208

205209
// Read lo, hi values from global memory
206210
complex_t<Scalar> lo, hi;
@@ -211,10 +215,10 @@ struct FFT<2,true, Scalar, device_capabilities>
211215
subgroup::FFT<true, Scalar, device_capabilities>::__call(lo, hi);
212216

213217
// If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
214-
if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
218+
if (WorkgroupSize > glsl::gl_SubgroupSize())
215219
{
216220
// Set up the memory adaptor
217-
using adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor,uint32_t,uint32_t,1,_NBL_HLSL_WORKGROUP_SIZE_>;
221+
using adaptor_t = accessor_adaptors::StructureOfArrays<SharedMemoryAccessor,uint32_t,uint32_t,1,WorkgroupSize>;
218222
adaptor_t sharedmemAdaptor;
219223
sharedmemAdaptor.accessor = sharedmemAccessor;
220224

@@ -223,18 +227,18 @@ struct FFT<2,true, Scalar, device_capabilities>
223227

224228
// The bigger steps
225229
[unroll]
226-
for (uint32_t stride = glsl::gl_SubgroupSize() << 1; stride < _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
230+
for (uint32_t stride = glsl::gl_SubgroupSize() << 1; stride < WorkgroupSize; stride <<= 1)
227231
{
228232
// Order of waiting for shared mem writes is also reversed here, since the shuffle came earlier
229233
sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
230234
FFT_loop< adaptor_t >(stride, lo, hi, threadID, sharedmemAdaptor);
231235
}
232236

233237
// special last iteration
234-
hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);
238+
hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID, WorkgroupSize), lo, hi);
235239
divides_assign< complex_t<Scalar> > divAss;
236-
divAss(lo, Scalar(_NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize()));
237-
divAss(hi, Scalar(_NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize()));
240+
divAss(lo, Scalar(WorkgroupSize / glsl::gl_SubgroupSize()));
241+
divAss(hi, Scalar(WorkgroupSize / glsl::gl_SubgroupSize()));
238242

239243
// Remember to update the accessor's state
240244
sharedmemAccessor = sharedmemAdaptor.accessor;
@@ -247,17 +251,17 @@ struct FFT<2,true, Scalar, device_capabilities>
247251
};
248252

249253
// Forward FFT
250-
template<uint32_t K, typename Scalar, class device_capabilities>
251-
struct FFT<K, false, Scalar, device_capabilities>
254+
template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
255+
struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
252256
{
253257
template<typename Accessor, typename SharedMemoryAccessor>
254258
static enable_if_t< (mpl::is_pot_v<K> && K > 2), void > __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
255259
{
256260
[unroll]
257-
for (uint32_t stride = (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; stride > _NBL_HLSL_WORKGROUP_SIZE_; stride >>= 1)
261+
for (uint32_t stride = (K / 2) * WorkgroupSize; stride > WorkgroupSize; stride >>= 1)
258262
{
259263
[unroll]
260-
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
264+
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * WorkgroupSize; virtualThreadID += WorkgroupSize)
261265
{
262266
const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
263267
const uint32_t hiIx = loIx | stride;
@@ -282,16 +286,16 @@ struct FFT<K, false, Scalar, device_capabilities>
282286
{
283287
if (k)
284288
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
285-
offsetAccessor.offset = _NBL_HLSL_WORKGROUP_SIZE_*k;
286-
FFT<2,false, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
289+
offsetAccessor.offset = WorkgroupSize*k;
290+
FFT<2,false, WorkgroupSize, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
287291
}
288292
accessor = offsetAccessor.accessor;
289293
}
290294
};
291295

292296
// Inverse FFT
293-
template<uint32_t K, typename Scalar, class device_capabilities>
294-
struct FFT<K, true, Scalar, device_capabilities>
297+
template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
298+
struct FFT<K, true, WorkgroupSize, Scalar, device_capabilities>
295299
{
296300
template<typename Accessor, typename SharedMemoryAccessor>
297301
static enable_if_t< (mpl::is_pot_v<K> && K > 2), void > __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
@@ -304,17 +308,17 @@ struct FFT<K, true, Scalar, device_capabilities>
304308
{
305309
if (k)
306310
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
307-
offsetAccessor.offset = _NBL_HLSL_WORKGROUP_SIZE_*k;
308-
FFT<2,true, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
311+
offsetAccessor.offset = WorkgroupSize*k;
312+
FFT<2,true, WorkgroupSize, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
309313
}
310314
accessor = offsetAccessor.accessor;
311315

312316
[unroll]
313-
for (uint32_t stride = 2 * _NBL_HLSL_WORKGROUP_SIZE_; stride < K * _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
317+
for (uint32_t stride = 2 * WorkgroupSize; stride < K * WorkgroupSize; stride <<= 1)
314318
{
315319
accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
316320
[unroll]
317-
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
321+
for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * WorkgroupSize; virtualThreadID += WorkgroupSize)
318322
{
319323
const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
320324
const uint32_t hiIx = loIx | stride;
@@ -326,7 +330,7 @@ struct FFT<K, true, Scalar, device_capabilities>
326330
hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & (stride - 1), stride), lo,hi);
327331

328332
// Divide by special factor at the end
329-
if ( (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_ == stride)
333+
if ( (K / 2) * WorkgroupSize == stride)
330334
{
331335
divides_assign< complex_t<Scalar> > divAss;
332336
divAss(lo, K / 2);

0 commit comments

Comments
 (0)