Skip to content

Commit 7ae2f52

Browse files
committed
Adding utils to get right order after FFT
1 parent 0de8b27 commit 7ae2f52

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

include/nbl/builtin/hlsl/fft/utils.hlsl

Whitespace-only changes.

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,68 @@ struct exchangeValues<SharedMemoryAdaptor, float64_t>
9797
template <typename scalar_t, uint32_t WorkgroupSize>
9898
NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof(complex_t<scalar_t>) / sizeof(uint32_t)) * WorkgroupSize;
9999

100+
101+
template<uint32_t N, uint32_t H>
102+
enable_if_t<H <= N, uint32_t> bitShiftRightHigher(uint32_t i)
103+
{
104+
// Highest H bits are numbered N-1 through N - H
105+
// N - H is then the middle bit
106+
// Lowest bits numbered from 0 through N - H - 1
107+
uint32_t low = i & ((1 << (N - H)) - 1);
108+
uint32_t mid = i & (1 << (N - H));
109+
uint32_t high = i & ~((1 << (N - H + 1)) - 1);
110+
111+
high >>= 1;
112+
mid <<= H - 1;
113+
114+
return mid | high | low;
115+
}
116+
117+
template<uint32_t N, uint32_t H>
118+
enable_if_t<H <= N, uint32_t> bitShiftLeftHigher(uint32_t i)
119+
{
120+
// Highest H bits are numbered N-1 through N - H
121+
// N - 1 is then the highest bit, and N - 2 through N - H are the middle bits
122+
// Lowest bits numbered from 0 through N - H - 1
123+
uint32_t low = i & ((1 << (N - H)) - 1);
124+
uint32_t mid = i & (~((1 << (N - H)) - 1) | ~(1 << (N - 1)));
125+
uint32_t high = i & (1 << (N - 1));
126+
127+
mid <<= 1;
128+
high >>= H - 1;
129+
130+
return mid | high | low;
131+
}
132+
133+
// For an N-bit number, mirrors it around the Nyquist frequency, which for the range [0, 2^N - 1] is precisely 2^(N - 1)
134+
template<uint32_t N>
135+
uint32_t mirror(uint32_t i)
136+
{
137+
return ((1 << N) - i) & ((1 << N) - 1)
138+
}
139+
140+
// This function maps the index `idx` in the output array of a Forward FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = output[idx]`
141+
// This is because Cooley-Tukey + subgroup operations end up spewing out the outputs in a weird order
142+
template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
143+
uint32_t getFrequencyAt(uint32_t idx)
144+
{
145+
NBL_CONSTEXPR_STATIC_INLINE uint32_t ELEMENTS_PER_INVOCATION_LOG_2 = uint32_t(mpl::log2<ElementsPerInvocation>::value);
146+
NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE_LOG_2 = ELEMENTS_PER_INVOCATION_LOG_2 + uint32_t(mpl::log2<WorkgroupSize>::value);
147+
148+
return mirror<FFT_SIZE_LOG_2>(bitShiftRightHigher<FFT_SIZE_LOG_2, FFT_SIZE_LOG_2 - ELEMENTS_PER_INVOCATION_LOG_2 + 1>(glsl::bitfieldReverse<uint32_t>(idx) >> (32 - FFT_SIZE_LOG_2)));
149+
}
150+
151+
// This function maps the index `freqIdx` in the DFT to the index `idx` in the output array of a Forward FFT such that `DFT[freqIdx] = output[idx]`
152+
// It is essentially the inverse of `getFrequencyAt`
153+
template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
154+
uint32_t getOutputAt(uint32_t freqIdx)
155+
{
156+
NBL_CONSTEXPR_STATIC_INLINE uint32_t ELEMENTS_PER_INVOCATION_LOG_2 = uint32_t(mpl::log2<ElementsPerInvocation>::value);
157+
NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE_LOG_2 = ELEMENTS_PER_INVOCATION_LOG_2 + uint32_t(mpl::log2<WorkgroupSize>::value);
158+
159+
return glsl::bitfieldReverse<uint32_t>(bitShiftLeftHigher<FFT_SIZE_LOG_2, FFT_SIZE_LOG_2 - ELEMENTS_PER_INVOCATION_LOG_2 + 1>(mirror<FFT_SIZE_LOG_2>(freqIdx))) >> (32 - FFT_SIZE_LOG_2);
160+
}
161+
100162
} //namespace fft
101163

102164
// ----------------------------------- End Utils -----------------------------------------------

0 commit comments

Comments
 (0)