@@ -97,6 +97,68 @@ struct exchangeValues<SharedMemoryAdaptor, float64_t>
97
97
template <typename scalar_t, uint32_t WorkgroupSize>
98
98
NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof (complex_t<scalar_t>) / sizeof (uint32_t)) * WorkgroupSize;
99
99
100
+
101
+ template<uint32_t N, uint32_t H>
102
+ enable_if_t<H <= N, uint32_t> bitShiftRightHigher (uint32_t i)
103
+ {
104
+ // Highest H bits are numbered N-1 through N - H
105
+ // N - H is then the middle bit
106
+ // Lowest bits numbered from 0 through N - H - 1
107
+ uint32_t low = i & ((1 << (N - H)) - 1 );
108
+ uint32_t mid = i & (1 << (N - H));
109
+ uint32_t high = i & ~((1 << (N - H + 1 )) - 1 );
110
+
111
+ high >>= 1 ;
112
+ mid <<= H - 1 ;
113
+
114
+ return mid | high | low;
115
+ }
116
+
117
+ template<uint32_t N, uint32_t H>
118
+ enable_if_t<H <= N, uint32_t> bitShiftLeftHigher (uint32_t i)
119
+ {
120
+ // Highest H bits are numbered N-1 through N - H
121
+ // N - 1 is then the highest bit, and N - 2 through N - H are the middle bits
122
+ // Lowest bits numbered from 0 through N - H - 1
123
+ uint32_t low = i & ((1 << (N - H)) - 1 );
124
+ uint32_t mid = i & (~((1 << (N - H)) - 1 ) | ~(1 << (N - 1 )));
125
+ uint32_t high = i & (1 << (N - 1 ));
126
+
127
+ mid <<= 1 ;
128
+ high >>= H - 1 ;
129
+
130
+ return mid | high | low;
131
+ }
132
+
133
+ // For an N-bit number, mirrors it around the Nyquist frequency, which for the range [0, 2^N - 1] is precisely 2^(N - 1)
134
+ template<uint32_t N>
135
+ uint32_t mirror (uint32_t i)
136
+ {
137
+ return ((1 << N) - i) & ((1 << N) - 1 )
138
+ }
139
+
140
+ // This function maps the index `idx` in the output array of a Forward FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = output[idx]`
141
+ // This is because Cooley-Tukey + subgroup operations end up spewing out the outputs in a weird order
142
+ template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
143
+ uint32_t getFrequencyAt (uint32_t idx)
144
+ {
145
+ NBL_CONSTEXPR_STATIC_INLINE uint32_t ELEMENTS_PER_INVOCATION_LOG_2 = uint32_t (mpl::log2<ElementsPerInvocation>::value);
146
+ NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE_LOG_2 = ELEMENTS_PER_INVOCATION_LOG_2 + uint32_t (mpl::log2<WorkgroupSize>::value);
147
+
148
+ return mirror <FFT_SIZE_LOG_2>(bitShiftRightHigher<FFT_SIZE_LOG_2, FFT_SIZE_LOG_2 - ELEMENTS_PER_INVOCATION_LOG_2 + 1 >(glsl::bitfieldReverse<uint32_t>(idx) >> (32 - FFT_SIZE_LOG_2)));
149
+ }
150
+
151
+ // This function maps the index `freqIdx` in the DFT to the index `idx` in the output array of a Forward FFT such that `DFT[freqIdx] = output[idx]`
152
+ // It is essentially the inverse of `getFrequencyAt`
153
+ template<uint16_t ElementsPerInvocation, uint32_t WorkgroupSize>
154
+ uint32_t getOutputAt (uint32_t freqIdx)
155
+ {
156
+ NBL_CONSTEXPR_STATIC_INLINE uint32_t ELEMENTS_PER_INVOCATION_LOG_2 = uint32_t (mpl::log2<ElementsPerInvocation>::value);
157
+ NBL_CONSTEXPR_STATIC_INLINE uint32_t FFT_SIZE_LOG_2 = ELEMENTS_PER_INVOCATION_LOG_2 + uint32_t (mpl::log2<WorkgroupSize>::value);
158
+
159
+ return glsl::bitfieldReverse<uint32_t>(bitShiftLeftHigher<FFT_SIZE_LOG_2, FFT_SIZE_LOG_2 - ELEMENTS_PER_INVOCATION_LOG_2 + 1 >(mirror <FFT_SIZE_LOG_2>(freqIdx))) >> (32 - FFT_SIZE_LOG_2);
160
+ }
161
+
100
162
} //namespace fft
101
163
102
164
// ----------------------------------- End Utils -----------------------------------------------
0 commit comments