@@ -13,67 +13,40 @@ namespace cp_algo::math::fft {
// Scalar and SIMD type aliases shared by the FFT code below.
// point: complex value over the scalar type ftype (ftype and complex are declared outside this view).
1313 using point = complex <ftype>;
// vftype: GCC vector-extension type holding ftype lanes, `bytes` bytes wide in total.
1414 using vftype [[gnu::vector_size(bytes)]] = ftype;
// vpoint: complex value whose real and imaginary parts are each a whole vftype vector.
1515 using vpoint = complex <vftype>;
16- static constexpr vftype fz = {};
// vz: value-initialized vector (the renamed fz); the rest of the file writes `vz + s`
// to obtain a vector with the scalar s in every lane.
16+ static constexpr vftype vz = {};
// vi: vpoint with real part vz and imaginary part vz + 1, i.e. the imaginary unit in every lane.
17+ static constexpr vpoint vi = {vz, vz + 1 };
1718
1819 struct cvector {
// Backing storage. The change replaces the two component vectors x/y with a single
// vector of vpoint blocks; element k is addressed as block k / flen, lane k % flen.
19- std::vector<vftype> x, y ;
20+ std::vector<vpoint> r ;
// Constructs storage for at least n scalar elements.
// n is first rounded up to a power of two (std::bit_ceil) and clamped to be no smaller
// than flen, then the storage is resized to n / flen vector blocks.
2021 cvector (size_t n) {
2122 n = std::max (flen, std::bit_ceil (n));
22- x.resize (n / flen);
23- y.resize (n / flen);
23+ r.resize (n / flen);
// checkpoint is defined outside this view; presumably a timing/profiling marker — TODO confirm.
2424 checkpoint (" cvector create" );
2525 }
26+
// Block accessors: k is a scalar element index, and the result is the vpoint block
// r[k / flen] that contains it (no bounds check is performed).
// The non-const overload returns a reference so callers can assign or update in place;
// the const overload returns the block by value.
27+ vpoint& at (size_t k) {return r[k / flen];}
28+ vpoint at (size_t k) const {return r[k / flen];}
// Stores t at scalar index k.
//  - pt == point: writes real(t)/imag(t) into lane k % flen of block k / flen.
//  - any other pt: assigns t to the whole block containing k via at(k)
//    (presumably only used with pt == vpoint — TODO confirm against callers).
2629 template <class pt = point>
2730 void set (size_t k, pt t) {
2831 if constexpr (std::is_same_v<pt, point>) {
29- x [k / flen][k % flen] = real (t);
30- y [k / flen][k % flen] = imag (t);
// Relies on real()/imag() of the project's complex type yielding a writable reference
// to the component vector — NOTE(review): confirm in the complex header.
32+ real (r [k / flen]) [k % flen] = real (t);
33+ imag (r [k / flen]) [k % flen] = imag (t);
3134 } else {
32- x[k / flen] = real (t);
33- y[k / flen] = imag (t);
35+ at (k) = t;
3436 }
3537 }
// Reads the value at scalar index k.
//  - pt == point: returns the single complex value taken from lane k % flen of block k / flen.
//  - any other pt: returns the whole block containing k, as produced by the const at(k).
3638 template <class pt = point>
3739 pt get (size_t k) const {
3840 if constexpr (std::is_same_v<pt, point>) {
39- return {x [k / flen][k % flen], y [k / flen][k % flen]};
41+ return {real (r [k / flen]) [k % flen], imag (r [k / flen]) [k % flen]};
4042 } else {
41- return {x[k / flen], y[k / flen]} ;
43+ return at (k) ;
4244 }
4345 }
// Removed by this change: thin wrapper that forwarded to get<vpoint>(k).
// The visible call sites in this diff now use at(k) instead.
44- vpoint vget (size_t k) const {
45- return get<vpoint>(k);
46- }
4746
// Number of scalar elements held: flen per stored block times the number of blocks.
4847 size_t size () const {
49- return flen * std::size (x );
48+ return flen * std::size (r );
5049 }
51-
// Compile-time lookup tables, removed from this position by the change
// (the same definitions reappear as `+` lines near the end of the struct).
// pre_roots: number of entries in each table (2^16).
52- static constexpr size_t pre_roots = 1 << 16 ;
// roots: for every power of two n < pre_roots and every 0 <= k < n,
// roots[n + k] = polar(1, pi * k / n). Index 0 is left value-initialized.
53- static constexpr std::array<point, pre_roots> roots = []() {
54- std::array<point, pre_roots> res = {};
55- for (size_t n = 1 ; n < res.size (); n *= 2 ) {
56- for (size_t k = 0 ; k < n; k++) {
57- res[n + k] = polar (1 ., std::numbers::pi / ftype (n) * ftype (k));
58- }
59- }
60- return res;
61- }();
// eval_args: eval_args[i] is i with its bits reversed within bit_width(i) bits.
// Built incrementally: the entry for i >> 1 supplies the reversed upper bits, and
// the lowest bit of i is moved to position bit_width(i) - 1 (`<<` binds tighter than `|`).
62- static constexpr std::array<size_t , pre_roots> eval_args = []() {
63- std::array<size_t , pre_roots> res = {};
64- for (size_t i = 1 ; i < pre_roots; i++) {
65- res[i] = res[i >> 1 ] | (i & 1 ) << (std::bit_width (i) - 1 );
66- }
67- return res;
68- }();
// evalp: evalp[0] = 1 and, for n >= 1,
// evalp[n] = polar(1, pi * eval_args[n] / (2 * bit_floor(n))).
69- static constexpr std::array<point, pre_roots> evalp = []() {
70- std::array<point, pre_roots> res = {};
71- res[0 ] = 1 ;
72- for (size_t n = 1 ; n < pre_roots; n++) {
73- res[n] = polar (1 ., std::numbers::pi * ftype (eval_args[n]) / ftype (2 * std::bit_floor (n)));
74- }
75- return res;
76- }();
7750 static size_t eval_arg (size_t n) {
7851 if (n < pre_roots) {
7952 return eval_args[n];
@@ -118,15 +91,17 @@ namespace cp_algo::math::fft {
11891 if (k / flen % 2 ) {
11992 rt = -rt;
12093 }
121- auto [Bvx, Bvy] = B.vget (k);
122- auto [Brvx, Brvy] = vpoint (Bvx, Bvy) * vpoint (fz + real (rt), fz + imag (rt));
123- auto [Ax, Ay] = A.vget (k);
124- vftype Bx[2 ] = {Brvx, Bvx}, By[2 ] = {Brvy, Bvy};
125- vpoint res = {fz, fz};
94+ auto [Ax, Ay] = A.at (k);
95+ auto Bv = B.at (k);
96+ vpoint res = {vz, vz};
12697 for (size_t i = 0 ; i < flen; i++) {
127- auto Bsx = (vftype*)((ftype*)Bx + flen - i);
128- auto Bsy = (vftype*)((ftype*)By + flen - i);
129- res += vpoint (fz + Ax[i], fz + Ay[i]) * vpoint{*Bsx, *Bsy};
98+ res += vpoint (vz + Ax[i], vz + Ay[i]) * Bv;
99+ real (Bv) = __builtin_shufflevector (real (Bv), real (Bv), 3 , 0 , 1 , 2 );
100+ imag (Bv) = __builtin_shufflevector (imag (Bv), imag (Bv), 3 , 0 , 1 , 2 );
101+ auto x = real (Bv)[0 ];
102+ auto y = imag (Bv)[0 ];
103+ real (Bv)[0 ] = x * real (rt) - y * imag (rt);
104+ imag (Bv)[0 ] = x * imag (rt) + y * real (rt);
130105 }
131106 return res;
132107 }
@@ -145,37 +120,36 @@ namespace cp_algo::math::fft {
145120 if (4 * i <= n) { // radix-4
146121 exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
147122 k *= 4 * i;
148- vpoint v1 = {fz + real (rt), fz - imag (rt)};
123+ vpoint v1 = {vz + real (rt), vz - imag (rt)};
149124 vpoint v2 = v1 * v1;
150125 vpoint v3 = v1 * v2;
151126 for (size_t j = k; j < k + i; j += flen) {
152- auto A = get<vpoint> (j);
153- auto B = get<vpoint> (j + i);
154- auto C = get<vpoint> (j + 2 * i);
155- auto D = get<vpoint> (j + 3 * i);
156- set (j , (A + B + C + D) );
157- set (j + 2 * i, (A + B - C - D) * v2) ;
158- set (j + i, (A - B - vpoint (fz, fz + 1 ) * (C - D)) * v1) ;
159- set (j + 3 * i, (A - B + vpoint (fz, fz + 1 ) * (C - D)) * v3) ;
127+ auto A = at (j);
128+ auto B = at (j + i);
129+ auto C = at (j + 2 * i);
130+ auto D = at (j + 3 * i);
131+ at (j) = (A + B + C + D);
132+ at (j + 2 * i) = (A + B - C - D) * v2;
133+ at (j + i) = (A - B - vi * (C - D)) * v1;
134+ at (j + 3 * i) = (A - B + vi * (C - D)) * v3;
160135 }
161136 });
162137 i *= 2 ;
163138 } else { // radix-2 fallback
164139 exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
165140 k *= 2 * i;
166- vpoint cvrt = {fz + real (rt), fz - imag (rt)};
141+ vpoint cvrt = {vz + real (rt), vz - imag (rt)};
167142 for (size_t j = k; j < k + i; j += flen) {
168- auto A = get<vpoint>(j) + get<vpoint>(j + i);
169- auto B = get<vpoint>(j) - get<vpoint>(j + i);
170- set (j, A);
171- set (j + i, B * cvrt);
143+ auto B = at (j) - at (j + i);
144+ at (j) += at (j + i);
145+ at (j + i) = B * cvrt;
172146 }
173147 });
174148 }
175149 }
176150 checkpoint (" ifft" );
177151 for (size_t k = 0 ; k < n; k += flen) {
178- set (k, get<vpoint>(k) /= fz + (ftype)(n / flen));
152+ set (k, get<vpoint>(k) /= vz + (ftype)(n / flen));
179153 }
180154 }
181155 void fft () {
@@ -185,34 +159,59 @@ namespace cp_algo::math::fft {
185159 i /= 2 ;
186160 exec_on_evals<2 >(n / (4 * i), [&](size_t k, point rt) {
187161 k *= 4 * i;
188- vpoint v1 = {fz + real (rt), fz + imag (rt)};
162+ vpoint v1 = {vz + real (rt), vz + imag (rt)};
189163 vpoint v2 = v1 * v1;
190164 vpoint v3 = v1 * v2;
191165 for (size_t j = k; j < k + i; j += flen) {
192- auto A = get<vpoint> (j);
193- auto B = get<vpoint> (j + i) * v1;
194- auto C = get<vpoint> (j + 2 * i) * v2;
195- auto D = get<vpoint> (j + 3 * i) * v3;
196- set (j , (A + C) + (B + D) );
197- set (j + i, (A + C) - (B + D) );
198- set (j + 2 * i, (A - C) + vpoint (fz, fz + 1 ) * (B - D) );
199- set (j + 3 * i, (A - C) - vpoint (fz, fz + 1 ) * (B - D) );
166+ auto A = at (j);
167+ auto B = at (j + i) * v1;
168+ auto C = at (j + 2 * i) * v2;
169+ auto D = at (j + 3 * i) * v3;
170+ at (j) = (A + C) + (B + D);
171+ at (j + i) = (A + C) - (B + D);
172+ at (j + 2 * i) = (A - C) + vi * (B - D);
173+ at (j + 3 * i) = (A - C) - vi * (B - D);
200174 }
201175 });
202176 } else { // radix-2 fallback
203177 exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
204178 k *= 2 * i;
205- vpoint vrt = {fz + real (rt), fz + imag (rt)};
179+ vpoint vrt = {vz + real (rt), vz + imag (rt)};
206180 for (size_t j = k; j < k + i; j += flen) {
207- auto t = get<vpoint> (j + i) * vrt;
208- set (j + i, get<vpoint> (j) - t) ;
209- set (j, get<vpoint>(j ) + t) ;
181+ auto t = at (j + i) * vrt;
182+ at (j + i) = at (j) - t;
183+ at (j ) += t ;
210184 }
211185 });
212186 }
213187 }
214188 checkpoint (" fft" );
215189 }
// Compile-time lookup tables, placed at the end of the struct by this change.
// pre_roots: number of entries in each table (2^16).
190+ static constexpr size_t pre_roots = 1 << 16 ;
// roots: for every power of two n < pre_roots and every 0 <= k < n,
// roots[n + k] = polar(1, pi * k / n). Index 0 is left value-initialized.
// polar is not declared in this view; presumably it builds the unit complex number
// with the given argument — TODO confirm.
191+ static constexpr std::array<point, pre_roots> roots = []() {
192+ std::array<point, pre_roots> res = {};
193+ for (size_t n = 1 ; n < res.size (); n *= 2 ) {
194+ for (size_t k = 0 ; k < n; k++) {
195+ res[n + k] = polar (1 ., std::numbers::pi / ftype (n) * ftype (k));
196+ }
197+ }
198+ return res;
199+ }();
// eval_args: eval_args[i] is i with its bits reversed within bit_width(i) bits.
// Built incrementally: the entry for i >> 1 supplies the reversed upper bits, and
// the lowest bit of i is moved to position bit_width(i) - 1 (`<<` binds tighter than `|`).
200+ static constexpr std::array<size_t , pre_roots> eval_args = []() {
201+ std::array<size_t , pre_roots> res = {};
202+ for (size_t i = 1 ; i < pre_roots; i++) {
203+ res[i] = res[i >> 1 ] | (i & 1 ) << (std::bit_width (i) - 1 );
204+ }
205+ return res;
206+ }();
// evalp: evalp[0] = 1 and, for n >= 1,
// evalp[n] = polar(1, pi * eval_args[n] / (2 * bit_floor(n))).
207+ static constexpr std::array<point, pre_roots> evalp = []() {
208+ std::array<point, pre_roots> res = {};
209+ res[0 ] = 1 ;
210+ for (size_t n = 1 ; n < pre_roots; n++) {
211+ res[n] = polar (1 ., std::numbers::pi * ftype (eval_args[n]) / ftype (2 * std::bit_floor (n)));
212+ }
213+ return res;
214+ }();
216215 };
217216}
218217#endif // CP_ALGO_MATH_CVECTOR_HPP
0 commit comments