11#ifndef CP_ALGO_MATH_CVECTOR_HPP
22#define CP_ALGO_MATH_CVECTOR_HPP
3- #include < algorithm>
4- #include < cassert>
5- #include < complex>
6- #include < vector>
3+ #include " ../util/complex.hpp"
4+ #include " ../util/checkpoint.hpp"
5+ #include < experimental/simd>
76#include < ranges>
87namespace cp_algo ::math::fft {
98 using ftype = double ;
10- static constexpr size_t bytes = 32 ;
11- static constexpr size_t flen = bytes / sizeof (ftype);
12- using point = std::complex <ftype>;
13- using vftype [[gnu::vector_size(bytes)]] = ftype;
14- using vpoint = std::complex <vftype>;
9+ using point = complex <ftype>;
10+ using vftype = std::experimental::native_simd<ftype>;
11+ using vpoint = complex <vftype>;
12+ static constexpr size_t flen = vftype::size();
1513
16- #define WITH_IV (...) \
17- [&]<size_t ... i>(std::index_sequence<i...>) { \
18- return __VA_ARGS__; \
19- }(std::make_index_sequence<flen>());
20-
21- template <typename ft>
22- constexpr ft to_ft (auto x) {
23- return ft{} + x;
24- }
25- template <typename pt>
26- constexpr pt to_pt (point r) {
27- using ft = std::conditional_t <std::is_same_v<point, pt>, ftype, vftype>;
28- return {to_ft<ft>(r.real ()), to_ft<ft>(r.imag ())};
29- }
3014 struct cvector {
31- static constexpr size_t pre_roots = 1 << 17 ;
15+ static constexpr size_t pre_roots = 1 << 15 ;
3216 std::vector<vftype> x, y;
3317 cvector (size_t n) {
3418 n = std::max (flen, std::bit_ceil (n));
3519 x.resize (n / flen);
3620 y.resize (n / flen);
21+ checkpoint (" cvector create" );
3722 }
3823 template <class pt = point>
3924 void set (size_t k, pt t) {
@@ -60,132 +45,147 @@ namespace cp_algo::math::fft {
6045 size_t size () const {
6146 return flen * std::size (x);
6247 }
48+
49+
50+ static auto dot_block (size_t k, cvector const & A, cvector const & B) {
51+ auto rt = eval_point (k / flen / 2 );
52+ if (k / flen % 2 ) {
53+ rt = -rt;
54+ }
55+ auto [Bvx, Bvy] = B.vget (k);
56+ auto [Brvx, Brvy] = vpoint (Bvx, Bvy) * vpoint (real (rt), imag (rt));
57+ auto [Ax, Ay] = A.vget (k);
58+ ftype Bx[2 * flen], By[2 * flen];
59+ Bvx.copy_to (Bx + flen, std::experimental::vector_aligned);
60+ Bvy.copy_to (By + flen, std::experimental::vector_aligned);
61+ Brvx.copy_to (Bx, std::experimental::vector_aligned);
62+ Brvy.copy_to (By, std::experimental::vector_aligned);
63+ vpoint res = {0 , 0 };
64+ for (size_t i = 0 ; i < flen; i++) {
65+ vftype Bsx, Bsy;
66+ Bsx.copy_from (Bx + flen - i, std::experimental::element_aligned);
67+ Bsy.copy_from (By + flen - i, std::experimental::element_aligned);
68+ res += vpoint (Ax[i], Ay[i]) * vpoint (Bsx, Bsy);
69+ }
70+ return res;
71+ }
72+
6373 void dot (cvector const & t) {
64- size_t n = size ();
74+ size_t n = this -> size ();
6575 for (size_t k = 0 ; k < n; k += flen) {
66- set (k, get<vpoint>(k) * t. get <vpoint>(k ));
76+ set (k, dot_block (k, * this , t ));
6777 }
78+ checkpoint (" dot" );
6879 }
69- static const cvector roots;
70- template <class pt = point>
71- static pt root (size_t n, size_t k) {
72- if (n < pre_roots) {
73- return roots.get <pt>(n + k);
80+ static const cvector roots, evalp;
81+ static std::array<size_t , pre_roots> eval_args;
82+
83+ template <bool precalc = false >
84+ static size_t eval_arg (size_t n) {
85+ if (n < pre_roots && !precalc) {
86+ return eval_args[n];
87+ } else if (n == 0 ) {
88+ return 0 ;
7489 } else {
75- auto arg = std::numbers::pi / ftype (n);
76- if constexpr (std::is_same_v<pt, point>) {
77- return {cos (ftype (k) * arg), sin (ftype (k) * arg)};
78- } else {
79- return WITH_IV (pt{vftype{cos (ftype (k + i) * arg)...},
80- vftype{sin (ftype (k + i) * arg)...}});
81- }
90+ return eval_arg (n / 2 ) | (n & 1 ) << (std::bit_width (n) - 1 );
8291 }
8392 }
84- template <class pt = point>
93+ template < bool precalc = false >
94+ static auto root (size_t n, size_t k) {
95+ if (n < pre_roots && !precalc) {
96+ return roots.get (n + k);
97+ } else {
98+ return polar (1 ., std::numbers::pi / (ftype)n * (ftype)k);
99+ }
100+ }
101+ template < bool precalc = false >
102+ static point eval_point (size_t n) {
103+ if (n < pre_roots && !precalc) {
104+ return evalp.get (n);
105+ } else if (n == 0 ) {
106+ return 1 ;
107+ } else {
108+ size_t N = std::bit_floor (n);
109+ return root (2 * N, eval_arg (n));
110+ }
111+ }
112+
113+ template <bool precalc = false >
85114 static void exec_on_roots (size_t n, size_t m, auto &&callback) {
86- size_t step = sizeof (pt) / sizeof (point);
87- pt cur;
88- pt arg = to_pt<pt>(root<point>(n, step));
89- for (size_t i = 0 ; i < m; i += step) {
90- if (i % 64 == 0 || n < pre_roots) {
91- cur = root<pt>(n, i);
115+ point cur;
116+ point arg = root<precalc>(n, 1 );
117+ for (size_t i = 0 ; i < m; i++) {
118+ if (precalc || i % 32 == 0 || n < pre_roots) {
119+ cur = root<precalc>(n, i);
92120 } else {
93121 cur *= arg;
94122 }
95123 callback (i, cur);
96124 }
97125 }
126+ static void exec_on_evals (size_t n, auto &&callback) {
127+ for (size_t i = 0 ; i < n; i++) {
128+ callback (i, eval_point (i));
129+ }
130+ }
98131
99132 void ifft () {
100133 size_t n = size ();
101- for (size_t i = 1 ; i < n; i *= 2 ) {
102- for (size_t j = 0 ; j < n; j += 2 * i) {
103- auto butterfly = [&]<class pt >(size_t k, pt rt) {
104- k += j;
105- auto t = get<pt>(k + i) * conj (rt);
106- set (k + i, get<pt>(k) - t);
107- set (k, get<pt>(k) + t);
108- };
109- if (2 * i <= flen) {
110- exec_on_roots (i, i, butterfly);
111- } else {
112- exec_on_roots<vpoint>(i, i, butterfly);
134+ for (size_t i = flen; i <= n / 2 ; i *= 2 ) {
135+ exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
136+ k *= 2 * i;
137+ vpoint vrt = {real (rt), imag (rt)};
138+ for (size_t j = k; j < k + i; j += flen) {
139+ auto A = get<vpoint>(j) + get<vpoint>(j + i);
140+ auto B = get<vpoint>(j) - get<vpoint>(j + i);
141+ set (j, A);
142+ set (j + i, B * conj (vrt));
113143 }
114- }
144+ });
115145 }
146+ checkpoint (" ifft" );
116147 for (size_t k = 0 ; k < n; k += flen) {
117- set (k, get<vpoint>(k) /= to_pt<vpoint> (ftype (n) ));
148+ set (k, get<vpoint>(k) /= (ftype)(n / flen ));
118149 }
119150 }
120151 void fft () {
121152 size_t n = size ();
122- for (size_t i = n / 2 ; i >= 1 ; i /= 2 ) {
123- for (size_t j = 0 ; j < n; j += 2 * i) {
124- auto butterfly = [&]<class pt >(size_t k, pt rt) {
125- k += j;
126- auto A = get<pt>(k) + get<pt>(k + i);
127- auto B = get<pt>(k) - get<pt>(k + i);
128- set (k, A);
129- set (k + i, B * rt);
130- };
131- if (2 * i <= flen) {
132- exec_on_roots (i, i, butterfly);
133- } else {
134- exec_on_roots<vpoint>(i, i, butterfly);
153+ for (size_t i = n / 2 ; i >= flen; i /= 2 ) {
154+ exec_on_evals (n / (2 * i), [&](size_t k, point rt) {
155+ k *= 2 * i;
156+ vpoint vrt = {real (rt), imag (rt)};
157+ for (size_t j = k; j < k + i; j += flen) {
158+ auto t = get<vpoint>(j + i) * vrt;
159+ set (j + i, get<vpoint>(j) - t);
160+ set (j, get<vpoint>(j) + t);
135161 }
136- }
162+ });
137163 }
164+ checkpoint (" fft" );
138165 }
139166 };
167+ std::array<size_t , cvector::pre_roots> cvector::eval_args = []() {
168+ std::array<size_t , pre_roots> res = {};
169+ for (size_t i = 1 ; i < pre_roots; i++) {
170+ res[i] = res[i >> 1 ] | (i & 1 ) << (std::bit_width (i) - 1 );
171+ }
172+ return res;
173+ }();
140174 const cvector cvector::roots = []() {
141175 cvector res (pre_roots);
142176 for (size_t n = 1 ; n < res.size (); n *= 2 ) {
143- auto base = std::polar (1 ., std::numbers::pi / ftype (n));
144- point cur = 1 ;
145- for (size_t k = 0 ; k < n; k++) {
146- if ((k & 15 ) == 0 ) {
147- cur = std::polar (1 ., std::numbers::pi * ftype (k) / ftype (n));
148- }
149- res.set (n + k, cur);
150- cur *= base;
151- }
177+ cvector::exec_on_roots<true >(n, n, [&](size_t k, auto rt) {
178+ res.set (n + k, rt);
179+ });
152180 }
153181 return res;
154182 }();
155-
156- template <typename base>
157- struct dft {
158- cvector A;
159-
160- dft (std::vector<base> const & a, size_t n): A(n) {
161- for (size_t i = 0 ; i < std::min (n, a.size ()); i++) {
162- A.set (i, a[i]);
163- }
164- if (n) {
165- A.fft ();
166- }
167- }
168-
169- std::vector<base> operator *= (dft const & B) {
170- assert (A.size () == B.A .size ());
171- size_t n = A.size ();
172- if (!n) {
173- return std::vector<base>();
174- }
175- A.dot (B.A );
176- A.ifft ();
177- std::vector<base> res (n);
178- for (size_t k = 0 ; k < n; k++) {
179- res[k] = A.get (k);
180- }
181- return res;
182- }
183-
184- auto operator * (dft const & B) const {
185- return dft (*this ) *= B;
183+ const cvector cvector::evalp = []() {
184+ cvector res (pre_roots);
185+ for (size_t n = 0 ; n < res.size (); n++) {
186+ res.set (n, cvector::eval_point<true >(n));
186187 }
187-
188- point operator [](int i) const {return A.get (i);}
189- };
188+ return res;
189+ }();
190190}
191191#endif // CP_ALGO_MATH_CVECTOR_HPP
0 commit comments