|
3 | 3 | #include "common.hpp"
|
4 | 4 | #include "modular.hpp"
|
5 | 5 | #include <algorithm>
|
| 6 | +#include <complex> |
6 | 7 | #include <cassert>
|
7 | 8 | #include <vector>
|
8 | 9 | namespace cp_algo::algebra::fft {
|
9 | 10 | using ftype = double;
|
10 |
| - struct point { |
11 |
| - ftype x, y; |
12 |
| - |
13 |
| - ftype real() {return x;} |
14 |
| - ftype imag() {return y;} |
15 |
| - |
16 |
| - point(): x(0), y(0){} |
17 |
| - point(ftype x, ftype y = 0): x(x), y(y){} |
18 |
| - |
19 |
| - static point polar(ftype rho, ftype ang) { |
20 |
| - return point{rho * cos(ang), rho * sin(ang)}; |
21 |
| - } |
22 |
| - |
23 |
| - point conj() const { |
24 |
| - return {x, -y}; |
25 |
| - } |
26 |
| - |
27 |
| - point operator +=(const point &t) {x += t.x, y += t.y; return *this;} |
28 |
| - point operator +(const point &t) const {return point(*this) += t;} |
29 |
| - point operator -(const point &t) const {return {x - t.x, y - t.y};} |
30 |
| - point operator *(const point &t) const {return {x * t.x - y * t.y, x * t.y + y * t.x};} |
31 |
| - }; |
| 11 | + using point = std::complex<ftype>; |
32 | 12 |
|
33 |
| - point w[maxn]; // w[2^n + k] = exp(pi * k / (2^n)) |
34 |
| - int bitr[maxn];// b[2^n + k] = bitreverse(k) |
| 13 | + std::vector<point> w; // w[2^n + k] = exp(pi * k / (2^n)) |
| 14 | + std::vector<int> bitr;// b[2^n + k] = bitreverse(k) |
35 | 15 | const ftype pi = acos(-1);
|
36 | 16 | bool initiated = 0;
|
37 | 17 | void init() {
|
38 | 18 | if(!initiated) {
|
| 19 | + w.resize(maxn); |
| 20 | + bitr.resize(maxn); |
39 | 21 | for(int i = 1; i < maxn; i *= 2) {
|
40 | 22 | int ti = i / 2;
|
41 | 23 | for(int j = 0; j < i; j++) {
|
42 |
| - w[i + j] = point::polar(ftype(1), pi * j / i); |
| 24 | + w[i + j] = std::polar(ftype(1), pi * j / i); |
43 | 25 | if(ti) {
|
44 | 26 | bitr[i + j] = 2 * bitr[ti + j % ti] + (j >= ti);
|
45 | 27 | }
|
@@ -113,8 +95,8 @@ namespace cp_algo::algebra::fft {
|
113 | 95 | }
|
114 | 96 | std::vector<point> C(n), D(n);
|
115 | 97 | for(size_t i = 0; i < n; i++) {
|
116 |
| - C[i] = A[i] * (B[i] + B[(n - i) % n].conj()); |
117 |
| - D[i] = A[i] * (B[i] - B[(n - i) % n].conj()); |
| 98 | + C[i] = A[i] * (B[i] + conj(B[(n - i) % n])); |
| 99 | + D[i] = A[i] * (B[i] - conj(B[(n - i) % n])); |
118 | 100 | }
|
119 | 101 | fft(C, n);
|
120 | 102 | fft(D, n);
|
|
0 commit comments