|
1 | 1 | #ifndef CP_ALGO_MATH_CVECTOR_HPP |
2 | 2 | #define CP_ALGO_MATH_CVECTOR_HPP |
3 | 3 | #include <algorithm> |
| 4 | +#include <cassert> |
4 | 5 | #include <complex> |
5 | 6 | #include <vector> |
6 | 7 | #include <ranges> |
7 | 8 | namespace cp_algo::math::fft { |
8 | 9 | using ftype = double; |
| 10 | + static constexpr size_t bytes = 32; |
| 11 | + static constexpr size_t flen = bytes / sizeof(ftype); |
9 | 12 | using point = std::complex<ftype>; |
| 13 | + using vftype [[gnu::vector_size(bytes)]] = ftype; |
| 14 | + using vpoint = std::complex<vftype>; |
10 | 15 |
|
11 | | - struct ftvec: std::vector<point> { |
12 | | - static constexpr size_t pre_roots = 1 << 16; |
13 | | - static constexpr size_t threshold = 32; |
14 | | - ftvec(size_t n) { |
15 | | - this->resize(std::max(threshold, std::bit_ceil(n))); |
| 16 | +#define WITH_IV(...) \ |
| 17 | + [&]<size_t ... i>(std::index_sequence<i...>) { \ |
| 18 | + return __VA_ARGS__; \ |
| 19 | + }(std::make_index_sequence<flen>()); |
| 20 | + |
| 21 | + template<typename ft> |
| 22 | + constexpr ft to_ft(auto x) { |
| 23 | + return ft{} + x; |
| 24 | + } |
| 25 | + template<typename pt> |
| 26 | + constexpr pt to_pt(point r) { |
| 27 | + using ft = std::conditional_t<std::is_same_v<point, pt>, ftype, vftype>; |
| 28 | + return {to_ft<ft>(r.real()), to_ft<ft>(r.imag())}; |
| 29 | + } |
| 30 | + struct cvector { |
| 31 | + static constexpr size_t pre_roots = 1 << 17; |
| 32 | + std::vector<vftype> x, y; |
| 33 | + cvector(size_t n) { |
| 34 | + n = std::max(flen, std::bit_ceil(n)); |
| 35 | + x.resize(n / flen); |
| 36 | + y.resize(n / flen); |
16 | 37 | } |
17 | | - static auto dot_block(size_t k, ftvec const& A, ftvec const& B) { |
18 | | - static std::array<point, 2 * threshold> r; |
19 | | - std::ranges::fill(r, point(0)); |
20 | | - for(size_t i = 0; i < threshold; i++) { |
21 | | - for(size_t j = 0; j < threshold; j++) { |
22 | | - r[i + j] += A[k + i] * B[k + j]; |
23 | | - } |
24 | | - } |
25 | | - auto rt = ftype(k / threshold % 2 ? -1 : 1) * eval_point(k / threshold / 2); |
26 | | - static std::array<point, threshold> res; |
27 | | - for(size_t i = 0; i < threshold; i++) { |
28 | | - res[i] = r[i] + r[i + threshold] * rt; |
| 38 | + template<class pt = point> |
| 39 | + void set(size_t k, pt t) { |
| 40 | + if constexpr(std::is_same_v<pt, point>) { |
| 41 | + x[k / flen][k % flen] = real(t); |
| 42 | + y[k / flen][k % flen] = imag(t); |
| 43 | + } else { |
| 44 | + x[k / flen] = real(t); |
| 45 | + y[k / flen] = imag(t); |
29 | 46 | } |
30 | | - return res; |
31 | 47 | } |
32 | | - |
33 | | - void dot(ftvec const& t) { |
34 | | - size_t n = this->size(); |
35 | | - for(size_t k = 0; k < n; k += threshold) { |
36 | | - std::ranges::copy(dot_block(k, *this, t), this->begin() + k); |
| 48 | + template<class pt = point> |
| 49 | + pt get(size_t k) const { |
| 50 | + if constexpr(std::is_same_v<pt, point>) { |
| 51 | + return {x[k / flen][k % flen], y[k / flen][k % flen]}; |
| 52 | + } else { |
| 53 | + return {x[k / flen], y[k / flen]}; |
37 | 54 | } |
38 | 55 | } |
39 | | - static std::array<point, pre_roots> roots, evalp; |
40 | | - static std::array<size_t, pre_roots> eval_args; |
41 | | - static point root(size_t n, size_t k) { |
42 | | - if(n + k < pre_roots && roots[n + k] != point{}) { |
43 | | - return roots[n + k]; |
44 | | - } |
45 | | - auto res = std::polar(1., std::numbers::pi * ftype(k) / ftype(n)); |
46 | | - if(n + k < pre_roots) { |
47 | | - roots[n + k] = res; |
48 | | - } |
49 | | - return res; |
| 56 | + vpoint vget(size_t k) const { |
| 57 | + return get<vpoint>(k); |
50 | 58 | } |
51 | | - static size_t eval_arg(size_t n) { |
52 | | - if(n < pre_roots && eval_args[n]) { |
53 | | - return eval_args[n]; |
54 | | - } else if(n == 0) { |
55 | | - return 0; |
56 | | - } |
57 | | - auto res = eval_arg(n / 2) | (n & 1) << (std::bit_width(n) - 1); |
58 | | - if(n < pre_roots) { |
59 | | - eval_args[n] = res; |
60 | | - } |
61 | | - return res; |
| 59 | + |
| 60 | + size_t size() const { |
| 61 | + return flen * std::size(x); |
62 | 62 | } |
63 | | - static point eval_point(size_t n) { |
64 | | - if(n < pre_roots && evalp[n] != point{}) { |
65 | | - return evalp[n]; |
66 | | - } else if(n == 0) { |
67 | | - return point(1); |
| 63 | + void dot(cvector const& t) { |
| 64 | + size_t n = size(); |
| 65 | + for(size_t k = 0; k < n; k += flen) { |
| 66 | + set(k, get<vpoint>(k) * t.get<vpoint>(k)); |
68 | 67 | } |
69 | | - auto res = root(2 * std::bit_floor(n), eval_arg(n)); |
| 68 | + } |
| 69 | + static const cvector roots; |
| 70 | + template<class pt = point> |
| 71 | + static pt root(size_t n, size_t k) { |
70 | 72 | if(n < pre_roots) { |
71 | | - evalp[n] = res; |
| 73 | + return roots.get<pt>(n + k); |
| 74 | + } else { |
| 75 | + auto arg = std::numbers::pi / ftype(n); |
| 76 | + if constexpr(std::is_same_v<pt, point>) { |
| 77 | + return {cos(ftype(k) * arg), sin(ftype(k) * arg)}; |
| 78 | + } else { |
| 79 | + return WITH_IV(pt{vftype{cos(ftype(k + i) * arg)...}, |
| 80 | + vftype{sin(ftype(k + i) * arg)...}}); |
| 81 | + } |
72 | 82 | } |
73 | | - return res; |
74 | 83 | } |
| 84 | + template<class pt = point> |
75 | 85 | static void exec_on_roots(size_t n, size_t m, auto &&callback) { |
76 | | - auto step = root(n, 1); |
77 | | - auto rt = point(1); |
78 | | - for(size_t i = 0; i < m; i++) { |
79 | | - if(i % threshold == 0) { |
80 | | - rt = root(n / threshold, i / threshold); |
| 86 | + size_t step = sizeof(pt) / sizeof(point); |
| 87 | + pt cur; |
| 88 | + pt arg = to_pt<pt>(root<point>(n, step)); |
| 89 | + for(size_t i = 0; i < m; i += step) { |
| 90 | + if(i % 64 == 0 || n < pre_roots) { |
| 91 | + cur = root<pt>(n, i); |
| 92 | + } else { |
| 93 | + cur *= arg; |
81 | 94 | } |
82 | | - callback(i, rt); |
83 | | - rt *= step; |
84 | | - } |
85 | | - } |
86 | | - static void exec_on_evals(size_t n, auto &&callback) { |
87 | | - for(size_t i = 0; i < n; i++) { |
88 | | - callback(i, eval_point(i)); |
| 95 | + callback(i, cur); |
89 | 96 | } |
90 | 97 | } |
91 | 98 |
|
92 | 99 | void ifft() { |
93 | | - size_t n = this->size(); |
94 | | - for(size_t half = threshold; half <= n / 2; half *= 2) { |
95 | | - exec_on_evals(n / (2 * half), [&](size_t k, point rt) { |
96 | | - k *= 2 * half; |
97 | | - for(size_t j = k; j < k + half; j++) { |
98 | | - auto A = this->at(j) + this->at(j + half); |
99 | | - auto B = this->at(j) - this->at(j + half); |
100 | | - this->at(j) = A; |
101 | | - this->at(j + half) = B * conj(rt); |
| 100 | + size_t n = size(); |
| 101 | + for(size_t i = 1; i < n; i *= 2) { |
| 102 | + for(size_t j = 0; j < n; j += 2 * i) { |
| 103 | + auto butterfly = [&]<class pt>(size_t k, pt rt) { |
| 104 | + k += j; |
| 105 | + auto t = get<pt>(k + i) * conj(rt); |
| 106 | + set(k + i, get<pt>(k) - t); |
| 107 | + set(k, get<pt>(k) + t); |
| 108 | + }; |
| 109 | + if(2 * i <= flen) { |
| 110 | + exec_on_roots(i, i, butterfly); |
| 111 | + } else { |
| 112 | + exec_on_roots<vpoint>(i, i, butterfly); |
102 | 113 | } |
103 | | - }); |
| 114 | + } |
104 | 115 | } |
105 | | - point ni = point(int(threshold)) / point(int(n)); |
106 | | - for(auto &it: *this) { |
107 | | - it *= ni; |
| 116 | + for(size_t k = 0; k < n; k += flen) { |
| 117 | + set(k, get<vpoint>(k) /= to_pt<vpoint>(ftype(n))); |
108 | 118 | } |
109 | 119 | } |
110 | 120 | void fft() { |
111 | | - size_t n = this->size(); |
112 | | - for(size_t half = n / 2; half >= threshold; half /= 2) { |
113 | | - exec_on_evals(n / (2 * half), [&](size_t k, point rt) { |
114 | | - k *= 2 * half; |
115 | | - for(size_t j = k; j < k + half; j++) { |
116 | | - auto t = this->at(j + half) * rt; |
117 | | - this->at(j + half) = this->at(j) - t; |
118 | | - this->at(j) += t; |
| 121 | + size_t n = size(); |
| 122 | + for(size_t i = n / 2; i >= 1; i /= 2) { |
| 123 | + for(size_t j = 0; j < n; j += 2 * i) { |
| 124 | + auto butterfly = [&]<class pt>(size_t k, pt rt) { |
| 125 | + k += j; |
| 126 | + auto A = get<pt>(k) + get<pt>(k + i); |
| 127 | + auto B = get<pt>(k) - get<pt>(k + i); |
| 128 | + set(k, A); |
| 129 | + set(k + i, B * rt); |
| 130 | + }; |
| 131 | + if(2 * i <= flen) { |
| 132 | + exec_on_roots(i, i, butterfly); |
| 133 | + } else { |
| 134 | + exec_on_roots<vpoint>(i, i, butterfly); |
119 | 135 | } |
120 | | - }); |
| 136 | + } |
| 137 | + } |
| 138 | + } |
| 139 | + }; |
| 140 | + const cvector cvector::roots = []() { |
| 141 | + cvector res(pre_roots); |
| 142 | + for(size_t n = 1; n < res.size(); n *= 2) { |
| 143 | + auto base = std::polar(1., std::numbers::pi / ftype(n)); |
| 144 | + point cur = 1; |
| 145 | + for(size_t k = 0; k < n; k++) { |
| 146 | + if((k & 15) == 0) { |
| 147 | + cur = std::polar(1., std::numbers::pi * ftype(k) / ftype(n)); |
| 148 | + } |
| 149 | + res.set(n + k, cur); |
| 150 | + cur *= base; |
| 151 | + } |
| 152 | + } |
| 153 | + return res; |
| 154 | + }(); |
| 155 | + |
| 156 | + template<typename base> |
| 157 | + struct dft { |
| 158 | + cvector A; |
| 159 | + |
| 160 | + dft(std::vector<base> const& a, size_t n): A(n) { |
| 161 | + for(size_t i = 0; i < std::min(n, a.size()); i++) { |
| 162 | + A.set(i, a[i]); |
| 163 | + } |
| 164 | + if(n) { |
| 165 | + A.fft(); |
121 | 166 | } |
122 | 167 | } |
| 168 | + |
| 169 | + std::vector<base> operator *= (dft const& B) { |
| 170 | + assert(A.size() == B.A.size()); |
| 171 | + size_t n = A.size(); |
| 172 | + if(!n) { |
| 173 | + return std::vector<base>(); |
| 174 | + } |
| 175 | + A.dot(B.A); |
| 176 | + A.ifft(); |
| 177 | + std::vector<base> res(n); |
| 178 | + for(size_t k = 0; k < n; k++) { |
| 179 | + res[k] = A.get(k); |
| 180 | + } |
| 181 | + return res; |
| 182 | + } |
| 183 | + |
| 184 | + auto operator * (dft const& B) const { |
| 185 | + return dft(*this) *= B; |
| 186 | + } |
| 187 | + |
| 188 | + point operator [](int i) const {return A.get(i);} |
123 | 189 | }; |
124 | | - std::array<point, ftvec::pre_roots> ftvec::roots = {}; |
125 | | - std::array<point, ftvec::pre_roots> ftvec::evalp = {}; |
126 | | - std::array<size_t, ftvec::pre_roots> ftvec::eval_args = {}; |
127 | 190 | } |
128 | 191 | #endif // CP_ALGO_MATH_CVECTOR_HPP |
0 commit comments