|
5 | 5 | #include "nbl/builtin/hlsl/workgroup/basic.hlsl"
|
6 | 6 | #include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
|
7 | 7 | #include "nbl/builtin/hlsl/workgroup/shuffle.hlsl"
|
| 8 | +#include "nbl/builtin/hlsl/mpl.hlsl" |
| 9 | +#include "nbl/builtin/hlsl/memory_accessor.hlsl" |
8 | 10 |
|
9 | 11 | namespace nbl
|
10 | 12 | {
|
@@ -159,21 +161,96 @@ struct FFT<2,true, Scalar, device_capabilities>
|
159 | 161 |
|
160 | 162 | /*
|
161 | 163 |
|
162 |
| -// then define 4,8,16 in terms of calling the FFT<2> and doing the special radix steps before/after |
163 |
| -template<uint16_t K, bool Inverse, class device_capabilities> |
164 |
| -struct FFT |
// Forward FFT
// Computes a forward FFT over K * _NBL_HLSL_WORKGROUP_SIZE_ elements held behind `accessor`:
// first log2(K)-1 radix-2 DIF passes across "virtual thread" strides (K/2 down to 2),
// then K/2 workgroup-sized FFT<2> sub-FFTs on consecutive _NBL_HLSL_WORKGROUP_SIZE_ offsets.
// NOTE(review): the removed generic template declared `uint16_t K` — confirm the primary
// template's NTTP type so this partial specialization matches it.
template<uint32_t K, typename Scalar, class device_capabilities>
struct FFT<K, false, Scalar, device_capabilities>
{
    // `accessor` provides get/set over the full K*WorkgroupSize range; `sharedmemAccessor`
    // is the scratch used by the FFT<2> workgroup kernel. SFINAE-gated to power-of-two K.
    template<typename Accessor, typename SharedMemoryAccessor>
    static enable_if_t<mpl::is_pot_v<K>, void> __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
    {
        static const uint32_t virtualThreadCount = K >> 1; // one butterfly per pair of elements
        static const uint16_t passes = mpl::log2<K>::value - 1; // strides K/2 .. 2; the final stride-1 stage is done by FFT<2>
        uint32_t stride = K >> 1;
        [unroll(passes)]
        for (uint16_t pass = 0; pass < passes; pass++)
        {
            [unroll(K / 2)]
            for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
            {
                const uint32_t virtualThreadID = virtualThread * _NBL_HLSL_WORKGROUP_SIZE_ + SubgroupContiguousIndex();

                // Classic radix-2 index split: keep the low `stride` bits, shift the rest up one.
                const uint32_t lsb = virtualThread & (stride - 1);
                const uint32_t loIx = ((virtualThread ^ lsb) << 1) | lsb;
                const uint32_t hiIx = loIx | stride;

                complex_t<Scalar> lo = accessor.get(loIx * _NBL_HLSL_WORKGROUP_SIZE_);
                complex_t<Scalar> hi = accessor.get(hiIx * _NBL_HLSL_WORKGROUP_SIZE_);

                fft::DIF<Scalar>::radix2(fft::twiddle<false, Scalar>(virtualThreadID & (stride - 1), stride), lo, hi);

                // FIX: write back to the same scaled addresses the values were loaded from
                // (the original stored to unscaled loIx/hiIx, clobbering the wrong elements).
                accessor.set(loIx * _NBL_HLSL_WORKGROUP_SIZE_, lo);
                accessor.set(hiIx * _NBL_HLSL_WORKGROUP_SIZE_, hi);
            }
            accessor.memoryBarrier(); // no execution barrier, just making sure writes propagate to accessor
            stride >>= 1;
        }

        // Do K/2 small workgroup FFTs, each on a _NBL_HLSL_WORKGROUP_SIZE_-sized window.
        OffsetAccessor<Accessor, complex_t<Scalar> > offsetAccessor;
        // FIX: the wrapped accessor must be seeded before use (the original left it default-initialized
        // yet copied it back into `accessor` afterwards).
        offsetAccessor.accessor = accessor;
        [unroll(K / 2)]
        for (uint32_t k = 0; k < K; k += 2)
        {
            if (k)
                sharedmemAccessor.executionAndMemoryBarrier(); // sub-FFTs reuse shared memory sequentially
            offsetAccessor.offset = _NBL_HLSL_WORKGROUP_SIZE_ * k;
            // FIX: `::template __call(` with no template-argument-list is ill-formed; let deduction work.
            FFT<2, false, Scalar, device_capabilities>::__call(offsetAccessor, sharedmemAccessor);
        }
        accessor = offsetAccessor.accessor;
    }
};
| 211 | +
|
// Inverse FFT
// Mirror of the forward transform: K/2 workgroup-sized FFT<2> inverse sub-FFTs first,
// then log2(K)-1 radix-2 passes with strides growing 2 .. K/2.
// NOTE(review): the removed generic template declared `uint16_t K` — confirm the primary
// template's NTTP type so this partial specialization matches it.
template<uint32_t K, typename Scalar, class device_capabilities>
struct FFT<K, true, Scalar, device_capabilities>
{
    // `accessor` provides get/set over the full K*WorkgroupSize range; `sharedmemAccessor`
    // is the scratch used by the FFT<2> workgroup kernel. SFINAE-gated to power-of-two K.
    template<typename Accessor, typename SharedMemoryAccessor>
    static enable_if_t<mpl::is_pot_v<K>, void> __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
    {
        // Do K/2 small workgroup FFTs, each on a _NBL_HLSL_WORKGROUP_SIZE_-sized window.
        OffsetAccessor<Accessor, complex_t<Scalar> > offsetAccessor;
        // FIX: the wrapped accessor must be seeded before use (the original left it default-initialized
        // yet copied it back into `accessor` afterwards).
        offsetAccessor.accessor = accessor;
        [unroll(K / 2)]
        for (uint32_t k = 0; k < K; k += 2)
        {
            if (k)
                sharedmemAccessor.executionAndMemoryBarrier(); // sub-FFTs reuse shared memory sequentially
            offsetAccessor.offset = _NBL_HLSL_WORKGROUP_SIZE_ * k;
            // FIX: `::template __call(` with no template-argument-list is ill-formed; let deduction work.
            FFT<2, true, Scalar, device_capabilities>::__call(offsetAccessor, sharedmemAccessor);
        }
        accessor = offsetAccessor.accessor;
        accessor.memoryBarrier(); // make the sub-FFT writes visible before the big-stride passes read them

        static const uint32_t virtualThreadCount = K >> 1; // one butterfly per pair of elements
        static const uint16_t passes = mpl::log2<K>::value - 1; // strides 2 .. K/2, mirroring the forward pass order
        // FIX: the original started at `K << 1`, which makes hiIx = loIx | stride address past the
        // K-element virtual range on the very first pass; the inverse strides grow from 2 up to K/2.
        uint32_t stride = 2;
        [unroll(passes)]
        for (uint16_t pass = 0; pass < passes; pass++)
        {
            [unroll(K / 2)]
            for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
            {
                // FIX: `virtualThreadID` was used but never declared here (it only existed in the
                // forward specialization) — guaranteed compile error.
                const uint32_t virtualThreadID = virtualThread * _NBL_HLSL_WORKGROUP_SIZE_ + SubgroupContiguousIndex();

                // Classic radix-2 index split: keep the low `stride` bits, shift the rest up one.
                const uint32_t lsb = virtualThread & (stride - 1);
                const uint32_t loIx = ((virtualThread ^ lsb) << 1) | lsb;
                const uint32_t hiIx = loIx | stride;

                complex_t<Scalar> lo = accessor.get(loIx * _NBL_HLSL_WORKGROUP_SIZE_);
                complex_t<Scalar> hi = accessor.get(hiIx * _NBL_HLSL_WORKGROUP_SIZE_);

                // FIX: the inverse of a DIF forward transform uses DIT butterflies —
                // NOTE(review): confirm `fft::DIT` is the helper declared alongside `fft::DIF`.
                fft::DIT<Scalar>::radix2(fft::twiddle<true, Scalar>(virtualThreadID & (stride - 1), stride), lo, hi);

                // FIX: write back to the same scaled addresses the values were loaded from
                // (the original stored to unscaled loIx/hiIx, clobbering the wrong elements).
                accessor.set(loIx * _NBL_HLSL_WORKGROUP_SIZE_, lo);
                accessor.set(hiIx * _NBL_HLSL_WORKGROUP_SIZE_, hi);
            }
            accessor.memoryBarrier(); // no execution barrier, just making sure writes propagate to accessor
            stride <<= 1;
        }
    }
};
|
|
0 commit comments