@@ -105,9 +105,10 @@ namespace cp_algo::math::fft {
105105 callback (i, cur);
106106 }
107107 }
// Calls callback(i, eval_point(step * i)) for every i in [0, n), in increasing order of i.
// `step` is a compile-time stride applied to the index passed to eval_point; the
// default of 1 reproduces the call shape of the removed ('-') line below, so callers
// that do not pass a template argument still receive eval_point(i).
// `callback` is invoked with two arguments: the loop index and the evaluation point.
108+ template <int step = 1 >
108109 static void exec_on_evals (size_t n, auto &&callback) {
109110 for (size_t i = 0 ; i < n; i++) {
// Removed by this diff: the index was previously passed to eval_point unscaled.
110- callback (i, eval_point (i));
// Added by this diff: the callback still receives the unscaled i; only the
// eval_point argument is multiplied by step.
111+ callback (i, eval_point (step * i));
111112 }
112113 }
113114 static auto dot_block (size_t k, cvector const & A, cvector const & B) {
@@ -145,16 +146,36 @@ namespace cp_algo::math::fft {
145146 void ifft () {
146147 size_t n = size ();
147148 for (size_t i = flen; i <= n / 2 ; i *= 2 ) {
148- exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
149- k *= 2 * i;
150- vpoint cvrt = {real (rt), -imag (rt)};
151- for (size_t j = k; j < k + i; j += flen) {
152- auto A = get<vpoint>(j) + get<vpoint>(j + i);
153- auto B = get<vpoint>(j) - get<vpoint>(j + i);
154- set (j, A);
155- set (j + i, B * cvrt);
156- }
157- });
149+ if (4 * i <= n) { // radix-4
150+ exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
151+ k *= 4 * i;
152+ vpoint v1 = {real (rt), -imag (rt)};
153+ vpoint v2 = v1 * v1;
154+ vpoint v3 = v1 * v2;
155+ for (size_t j = k; j < k + i; j += flen) {
156+ auto A = get<vpoint>(j);
157+ auto B = get<vpoint>(j + i);
158+ auto C = get<vpoint>(j + 2 * i);
159+ auto D = get<vpoint>(j + 3 * i);
160+ set (j , (A + B + C + D));
161+ set (j + 2 * i, (A + B - C - D) * v2);
162+ set (j + i, (A - B - vpoint (0 , 1 ) * (C - D)) * v1);
163+ set (j + 3 * i, (A - B + vpoint (0 , 1 ) * (C - D)) * v3);
164+ }
165+ });
166+ i *= 2 ;
167+ } else { // radix-2 fallback
168+ exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
169+ k *= 2 * i;
170+ vpoint cvrt = {real (rt), -imag (rt)};
171+ for (size_t j = k; j < k + i; j += flen) {
172+ auto A = get<vpoint>(j) + get<vpoint>(j + i);
173+ auto B = get<vpoint>(j) - get<vpoint>(j + i);
174+ set (j, A);
175+ set (j + i, B * cvrt);
176+ }
177+ });
178+ }
158179 }
159180 checkpoint (" ifft" );
160181 for (size_t k = 0 ; k < n; k += flen) {
@@ -164,15 +185,35 @@ namespace cp_algo::math::fft {
// Runs butterfly passes over the size() stored elements, reading operands with
// get<vpoint>() and writing results back through set(). The pass parameter i starts
// at n / 2 and shrinks towards flen; the inner loops advance j in steps of flen
// (presumably one vpoint covers flen consecutive elements -- confirm against get/set).
164185 void fft () {
165186 size_t n = size ();
166187 for (size_t i = n / 2 ; i >= flen; i /= 2 ) {
// Lines marked '-' are the previous loop body removed by this diff: a single
// radix-2 pass per iteration, combining positions j and j + i with eval_point(k).
167- exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
168- k *= 2 * i;
169- vpoint vrt = {real (rt), imag (rt)};
170- for (size_t j = k; j < k + i; j += flen) {
171- auto t = get<vpoint>(j + i) * vrt;
172- set (j + i, get<vpoint>(j) - t);
173- set (j, get<vpoint>(j) + t);
174- }
175- });
// Added by this diff: whenever i can be halved and still be >= flen, a single
// merged pass is used in place of one radix-2 pass.
188+ if (i / 2 >= flen) { // radix-4
// i is halved here and halved again by the loop's own i /= 2, so this branch
// consumes two halvings of i per iteration. After this line i is the
// quarter-block length: each block spans 4 * i elements.
189+ i /= 2 ;
// exec_on_evals<2> hands the callback rt = eval_point(2 * k) for block index k.
190+ exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
// Turn the block index into the offset of the block's first element.
191+ k *= 4 * i;
// v1 is rt converted to vpoint; v2 and v3 are its square and cube.
192+ vpoint v1 = {real (rt), imag (rt)};
193+ vpoint v2 = v1 * v1;
194+ vpoint v3 = v1 * v2;
195+ for (size_t j = k; j < k + i; j += flen) {
// Load the four quarter-block operands, scaling the last three by v1, v2, v3.
196+ auto A = get<vpoint>(j);
197+ auto B = get<vpoint>(j + i) * v1;
198+ auto C = get<vpoint>(j + 2 * i) * v2;
199+ auto D = get<vpoint>(j + 3 * i) * v3;
// The first two outputs use the sums (A + C) and (B + D); the last two use the
// differences, with (B - D) multiplied by vpoint(0, 1) -- presumably the
// imaginary unit, following the {real, imag} order used to build v1.
// NOTE(review): this equals two consecutive radix-2 passes only if
// eval_point(2k)^2 == eval_point(k) and eval_point(2k + 1) == i * eval_point(2k);
// eval_point is not visible here -- confirm against its definition.
200+ set (j , (A + C) + (B + D));
201+ set (j + i, (A + C) - (B + D));
202+ set (j + 2 * i, (A - C) + vpoint (0 , 1 ) * (B - D));
203+ set (j + 3 * i, (A - C) - vpoint (0 , 1 ) * (B - D));
204+ }
205+ });
// Reached only when i / 2 < flen, i.e. i cannot be halved again. The body is the
// same radix-2 pass as the removed ('-') lines above.
206+ } else { // radix-2 fallback
207+ exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
208+ k *= 2 * i;
209+ vpoint vrt = {real (rt), imag (rt)};
210+ for (size_t j = k; j < k + i; j += flen) {
// t must be computed before either set(): both outputs read the original value at j.
211+ auto t = get<vpoint>(j + i) * vrt;
212+ set (j + i, get<vpoint>(j) - t);
213+ set (j, get<vpoint>(j) + t);
214+ }
215+ });
216+ }
176217 }
177218 checkpoint (" fft" );
178219 }
0 commit comments