@@ -24,14 +24,6 @@ namespace cp_algo::math::fft {
2424 r.resize (n / flen);
2525 checkpoint (" cvector create" );
2626 }
27- cvector (cvector const & t) {
28- r.resize (t.r .size ());
29- for (size_t i = 0 ; i < r.size (); i++) {
30- r[i] = {vftype (t.r [i].real ()), vftype (t.r [i].imag ())};
31- }
32- checkpoint (" cvector copy" );
33- }
34- cvector (cvector&& t) = delete ;
3527
3628 vpoint& at (size_t k) {return r[k / flen];}
3729 vpoint at (size_t k) const {return r[k / flen];}
@@ -63,74 +55,53 @@ namespace cp_algo::math::fft {
6355 return eval_arg (n / 2 ) | (n & 1 ) << (std::bit_width (n) - 1 );
6456 }
6557 }
66- static auto root (size_t n, size_t k) {
67- if (n < pre_roots) {
68- return roots[n + k];
69- } else if (k % 2 == 0 ) {
70- return root (n / 2 , k / 2 );
71- } else {
72- return polar (1 ., std::numbers::pi / (ftype)n * (ftype)k);
73- }
74- }
7558 static point eval_point (size_t n) {
7659 if (n % 2 ) {
77- return eval_point (n - 1 ) * point (0 , 1 );
78- } else if (n / 2 < pre_evals) {
79- return evalp[n / 2 ];
60+ return -eval_point (n - 1 );
61+ } else if (n % 4 ) {
62+ return eval_point (n - 2 ) * point (0 , 1 );
63+ } else if (n / 4 < pre_evals) {
64+ return evalp[n / 4 ];
8065 } else {
81- return root ( 2 * std::bit_floor (n), eval_arg (n));
66+ return polar ( 1 ., std::numbers::pi / (ftype) std::bit_floor (n) * (ftype) eval_arg (n));
8267 }
8368 }
84- static void exec_on_roots (size_t n, size_t m, auto &&callback) {
85- point cur = {1 , 0 };
86- point arg = root (n, 1 );
87- for (size_t i = 0 ; i < m; i++) {
88- callback (i, cur);
89- if (i % 64 == 63 ) {
90- cur = root (n / 64 , i / 64 + 1 );
91- } else {
92- cur *= arg;
93- }
94- }
69+ static point root (size_t n) {
70+ return polar (1 ., 2 . * std::numbers::pi / (ftype)n);
9571 }
96- template <int step = 1 >
72+ template <int step>
9773 static void exec_on_evals (size_t n, auto &&callback) {
74+ point factor = root (4 * step * n);
9875 for (size_t i = 0 ; i < n; i++) {
99- callback (i, eval_point (step * i));
100- }
101- }
102- static auto dot_block (size_t k, cvector const & A, cvector const & B) {
103- auto rt = eval_point (k / flen / 2 );
104- if (k / flen % 2 ) {
105- rt = -rt;
76+ callback (i, factor * eval_point (step * i));
10677 }
107- auto [Ax, Ay] = A.at (k);
108- auto Bv = B.at (k);
109- vpoint res = vz;
110- for (size_t i = 0 ; i < flen; i++) {
111- res += vpoint (vz + Ax[i], vz + Ay[i]) * Bv;
112- real (Bv) = __builtin_shufflevector (real (Bv), real (Bv), 3 , 0 , 1 , 2 );
113- imag (Bv) = __builtin_shufflevector (imag (Bv), imag (Bv), 3 , 0 , 1 , 2 );
114- auto x = real (Bv)[0 ], y = imag (Bv)[0 ];
115- real (Bv)[0 ] = x * real (rt) - y * imag (rt);
116- imag (Bv)[0 ] = x * imag (rt) + y * real (rt);
117- }
118- return res;
11978 }
12079
12180 void dot (cvector const & t) {
12281 size_t n = this ->size ();
123- for (size_t k = 0 ; k < n; k += flen) {
124- set (k, dot_block (k, *this , t));
125- }
82+ exec_on_evals<1 >(n / flen, [&](size_t k, point rt) {
83+ k *= flen;
84+ auto [Ax, Ay] = at (k);
85+ auto Bv = t.at (k);
86+ vpoint res = vz;
87+ for (size_t i = 0 ; i < flen; i++) {
88+ res += vpoint (vz + Ax[i], vz + Ay[i]) * Bv;
89+ real (Bv) = __builtin_shufflevector (real (Bv), real (Bv), 3 , 0 , 1 , 2 );
90+ imag (Bv) = __builtin_shufflevector (imag (Bv), imag (Bv), 3 , 0 , 1 , 2 );
91+ auto x = real (Bv)[0 ], y = imag (Bv)[0 ];
92+ real (Bv)[0 ] = x * real (rt) - y * imag (rt);
93+ imag (Bv)[0 ] = x * imag (rt) + y * real (rt);
94+ }
95+ set (k, res);
96+ });
12697 checkpoint (" dot" );
12798 }
12899
129100 void ifft () {
130101 size_t n = size ();
131102 for (size_t i = flen; i <= n / 2 ; i *= 2 ) {
132103 if (4 * i <= n) { // radix-4
133- exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
104+ exec_on_evals<4 >(n / (4 * i), [&](size_t k, point rt) {
134105 k *= 4 * i;
135106 vpoint v1 = {vz + real (rt), vz - imag (rt)};
136107 vpoint v2 = v1 * v1;
@@ -148,7 +119,7 @@ namespace cp_algo::math::fft {
148119 });
149120 i *= 2 ;
150121 } else { // radix-2 fallback
151- exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
122+ exec_on_evals< 2 > (n / (2 * i), [&](size_t k, point rt) {
152123 k *= 2 * i;
153124 vpoint cvrt = {vz + real (rt), vz - imag (rt)};
154125 for (size_t j = k; j < k + i; j += flen) {
@@ -169,7 +140,7 @@ namespace cp_algo::math::fft {
169140 for (size_t i = n / 2 ; i >= flen; i /= 2 ) {
170141 if (i / 2 >= flen) { // radix-4
171142 i /= 2 ;
172- exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
143+ exec_on_evals<4 >(n / (4 * i), [&](size_t k, point rt) {
173144 k *= 4 * i;
174145 vpoint v1 = {vz + real (rt), vz + imag (rt)};
175146 vpoint v2 = v1 * v1;
@@ -186,7 +157,7 @@ namespace cp_algo::math::fft {
186157 }
187158 });
188159 } else { // radix-2 fallback
189- exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
160+ exec_on_evals< 2 > (n / (2 * i), [&](size_t k, point rt) {
190161 k *= 2 * i;
191162 vpoint vrt = {vz + real (rt), vz + imag (rt)};
192163 for (size_t j = k; j < k + i; j += flen) {
@@ -199,17 +170,7 @@ namespace cp_algo::math::fft {
199170 }
200171 checkpoint (" fft" );
201172 }
202- static constexpr size_t pre_roots = 1 << 14 ;
203173 static constexpr size_t pre_evals = 1 << 16 ;
204- static constexpr std::array<point, pre_roots> roots = []() {
205- std::array<point, pre_roots> res = {};
206- for (size_t n = 1 ; n < res.size (); n *= 2 ) {
207- for (size_t k = 0 ; k < n; k++) {
208- res[n + k] = polar (1 ., std::numbers::pi / ftype (n) * ftype (k));
209- }
210- }
211- return res;
212- }();
213174 static constexpr std::array<size_t , pre_evals> eval_args = []() {
214175 std::array<size_t , pre_evals> res = {};
215176 for (size_t i = 1 ; i < pre_evals; i++) {
0 commit comments