@@ -15,6 +15,14 @@ namespace cp_algo::math::fft {
15
15
std::vector<int > bitr;// b[2^n + k] = bitreverse(k)
16
16
const ftype pi = acos(-1 );
17
17
bool initiated = 0 ;
18
+ size_t bitreverse (size_t n, size_t k) {
19
+ size_t hn = n / 2 ;
20
+ if (k >= hn) {
21
+ return 2 * bitr[k] + 1 ;
22
+ } else {
23
+ return 2 * bitr[hn + k];
24
+ }
25
+ }
18
26
void init () {
19
27
if (!initiated) {
20
28
w.resize (maxn);
@@ -32,24 +40,34 @@ namespace cp_algo::math::fft {
32
40
}
33
41
}
34
42
35
- void fft (auto &a, int n) {
43
+ void ifft (auto &a, int n) {
36
44
init ();
37
45
if (n == 1 ) {
38
46
return ;
39
47
}
40
- int hn = n / 2 ;
41
- for (int i = 0 ; i < n; i++) {
42
- int ti = 2 * bitr[hn + i % hn] + (i > hn);
43
- if (i < ti) {
44
- std::swap (a[i], a[ti]);
48
+ for (int i = 1 ; i < n; i *= 2 ) {
49
+ for (int j = 0 ; j < n; j += 2 * i) {
50
+ for (int k = j; k < j + i; k++) {
51
+ std::tie (a[k], a[k + i]) = std::pair{
52
+ a[k] + a[k + i] * conj (w[i + k - j]),
53
+ a[k] - a[k + i] * conj (w[i + k - j])
54
+ };
55
+ }
45
56
}
46
57
}
47
- for (int i = 1 ; i < n; i *= 2 ) {
58
+ }
59
+ void fft (auto &a, int n) {
60
+ init ();
61
+ if (n == 1 ) {
62
+ return ;
63
+ }
64
+ for (int i = n / 2 ; i >= 1 ; i /= 2 ) {
48
65
for (int j = 0 ; j < n; j += 2 * i) {
49
66
for (int k = j; k < j + i; k++) {
50
- point t = a[k + i] * w[i + k - j];
51
- a[k + i] = a[k] - t;
52
- a[k] += t;
67
+ std::tie (a[k], a[k + i]) = std::pair{
68
+ a[k] + a[k + i],
69
+ (a[k] - a[k + i]) * w[i + k - j]
70
+ };
53
71
}
54
72
}
55
73
}
@@ -93,8 +111,7 @@ namespace cp_algo::math::fft {
93
111
for (size_t i = 0 ; i < n; i++) {
94
112
A[i] *= B[i];
95
113
}
96
- fft (A, n);
97
- reverse (begin (A) + 1 , end (A));
114
+ ifft (A, n);
98
115
for (size_t i = 0 ; i < n; i++) {
99
116
A[i] /= n;
100
117
}
@@ -126,28 +143,27 @@ namespace cp_algo::math::fft {
126
143
fft (A, n);
127
144
}
128
145
}
129
-
130
- std::vector<base> operator *= (dft const & B) {
146
+
147
+ std::vector<base> operator *= (dft B) {
131
148
assert (A.size () == B.A .size ());
132
149
size_t n = A.size ();
133
150
if (!n) {
134
151
return std::vector<base>();
135
152
}
136
153
std::vector<point> C (n);
137
- for (size_t i = 0 ; 2 * i <= n; i++) {
138
- int x = i;
139
- int y = (n - i) % n;
154
+ for (size_t i = 0 ; 2 * i <= n; i++) {//
155
+ size_t j = (n - i) % n;
156
+ size_t x = bitreverse (n, i);
157
+ size_t y = bitreverse (n, j);
140
158
std::tie (C[x], A[x], C[y], A[y]) = std::make_tuple (
141
159
A[x] * (B[x] + conj (B[y])),
142
160
A[x] * (B[x] - conj (B[y])),
143
161
A[y] * (B[y] + conj (B[x])),
144
162
A[y] * (B[y] - conj (B[x]))
145
163
);
146
164
}
147
- fft (C, n);
148
- fft (A, n);
149
- reverse (begin (C) + 1 , end (C));
150
- reverse (begin (A) + 1 , end (A));
165
+ ifft (C, n);
166
+ ifft (A, n);
151
167
int t = 2 * n;
152
168
std::vector<base> res (n);
153
169
for (size_t i = 0 ; i < n; i++) {
0 commit comments