7
7
#include < cassert>
8
8
#include < vector>
9
9
#include < bit>
10
- namespace cp_algo ::math::fft {
11
- const auto bitr = [](){
12
- std::vector<size_t > bitr (maxn);
13
- for (size_t n = 2 ; n < maxn; n *= 2 ) {
14
- for (size_t k = 0 ; k < n; k++) {
15
- bitr[n + k] = bitr[n + k / 2 ] / 2 + (k & 1 ) * (n / 2 );
16
- }
17
- }
18
- return bitr;
19
- }();
20
- size_t bitreverse (size_t n, size_t k) {
21
- size_t hn = n / 2 ;
22
- if (k >= hn && n > 1 ) {
23
- return 2 * bitr[k] + 1 ;
24
- } else {
25
- return 2 * bitr[hn + k];
26
- }
27
- }
28
10
11
+ namespace cp_algo ::math::fft {
29
12
using ftype = double ;
30
13
static constexpr size_t bytes = 32 ;
31
14
static constexpr size_t flen = bytes / sizeof (ftype);
32
15
using point = std::complex<ftype>;
33
16
using vftype [[gnu::vector_size(bytes)]] = ftype;
34
17
using vpoint = std::complex<vftype>;
35
18
36
- constexpr vftype to_vec (ftype x) {
37
- return vftype{} + x;
19
+ #define WITH_IV (...) \
20
+ [&]<size_t ... i>(std::index_sequence<i...>) { \
21
+ return __VA_ARGS__; \
22
+ }(std::make_index_sequence<flen>());
23
+
24
+ template <typename ft>
25
+ constexpr ft to_ft (auto x) {
26
+ return ft{} + x;
38
27
}
39
- constexpr vpoint to_vec (point r) {
40
- return {to_vec (r.real ()), to_vec (r.imag ())};
28
+ template <typename pt>
29
+ constexpr pt to_pt (point r) {
30
+ using ft = std::conditional_t <std::is_same_v<point, pt>, ftype, vftype>;
31
+ return {to_ft<ft>(r.real ()), to_ft<ft>(r.imag ())};
41
32
}
42
33
struct cvector {
34
+ static constexpr size_t pre_roots = 1 << 17 ;
43
35
std::vector<vftype> x, y;
44
36
cvector () {}
45
37
cvector (size_t n) {
@@ -84,54 +76,74 @@ namespace cp_algo::math::fft {
84
76
}
85
77
}
86
78
static const cvector roots;
79
+ template <class pt = point>
80
+ static pt root (size_t n, size_t k) {
81
+ if (n < pre_roots) {
82
+ return cvector::roots.get <pt>(n + k);
83
+ } else {
84
+ auto arg = std::numbers::pi / n;
85
+ if constexpr (std::is_same_v<pt, point>) {
86
+ return {cos (k * arg), sin (k * arg)};
87
+ } else {
88
+ return WITH_IV (pt{vftype{cos ((k + i) * arg)...},
89
+ vftype{sin ((k + i) * arg)...}});
90
+ }
91
+ }
92
+ }
93
+ template <class pt = point>
94
+ static void exec_on_roots (size_t n, size_t m, auto &&callback) {
95
+ size_t step = sizeof (pt) / sizeof (point);
96
+ pt cur;
97
+ pt arg = to_pt<pt>(root<point>(n, step));
98
+ for (size_t i = 0 ; i < m; i += step) {
99
+ cur = (i & 63 ) == 0 || n < pre_roots ? root<pt>(n, i) : cur * arg;
100
+ callback (i, cur);
101
+ }
102
+ }
87
103
88
104
void ifft () {
89
105
size_t n = size ();
90
106
for (size_t i = 1 ; i < n; i *= 2 ) {
91
107
for (size_t j = 0 ; j < n; j += 2 * i) {
92
- auto butterfly = [&]<class pt >(pt) {
93
- size_t step = sizeof (pt) / sizeof (point);
94
- for (size_t k = j; k < j + i; k += step) {
95
- auto T = get<pt>(k + i) * conj (roots.get <pt>(i + k - j));
96
- set (k + i, get<pt>(k) - T);
97
- set (k, get<pt>(k) + T);
98
- }
108
+ auto butterfly = [&]<class pt >(size_t k, pt rt) {
109
+ k += j;
110
+ auto t = get<pt>(k + i) * conj (rt);
111
+ set (k + i, get<pt>(k) - t);
112
+ set (k, get<pt>(k) + t);
99
113
};
100
114
if (2 * i <= flen) {
101
- butterfly (point{} );
115
+ exec_on_roots (i, i, butterfly );
102
116
} else {
103
- butterfly ( vpoint{} );
117
+ exec_on_roots< vpoint>(i, i, butterfly );
104
118
}
105
119
}
106
120
}
107
121
for (size_t k = 0 ; k < n; k += flen) {
108
- set (k, get<vpoint>(k) /= to_vec (n));
122
+ set (k, get<vpoint>(k) /= to_pt<vpoint> (n));
109
123
}
110
124
}
111
125
void fft () {
112
126
size_t n = size ();
113
127
for (size_t i = n / 2 ; i >= 1 ; i /= 2 ) {
114
128
for (size_t j = 0 ; j < n; j += 2 * i) {
115
- auto butterfly = [&]<class pt >(pt) {
116
- size_t step = sizeof (pt) / sizeof (point);
117
- for (size_t k = j; k < j + i; k += step) {
118
- auto A = get<pt>(k) + get<pt>(k + i);
119
- auto B = get<pt>(k) - get<pt>(k + i);
120
- set (k, A);
121
- set (k + i, B * roots.get <pt>(i + k - j));
122
- }
129
+ auto butterfly = [&]<class pt >(size_t k, pt rt) {
130
+ k += j;
131
+ auto A = get<pt>(k) + get<pt>(k + i);
132
+ auto B = get<pt>(k) - get<pt>(k + i);
133
+ set (k, A);
134
+ set (k + i, B * rt);
123
135
};
124
136
if (2 * i <= flen) {
125
- butterfly (point{} );
137
+ exec_on_roots (i, i, butterfly );
126
138
} else {
127
- butterfly ( vpoint{} );
139
+ exec_on_roots< vpoint>(i, i, butterfly );
128
140
}
129
141
}
130
142
}
131
143
}
132
144
};
133
145
const cvector cvector::roots = []() {
134
- cvector res (2 * maxn );
146
+ cvector res (pre_roots );
135
147
for (size_t n = 1 ; n < res.size (); n *= 2 ) {
136
148
auto base = std::polar (1 ., std::numbers::pi / n);
137
149
point cur = 1 ;
@@ -145,15 +157,6 @@ namespace cp_algo::math::fft {
145
157
}
146
158
return res;
147
159
}();
148
- point root (size_t n, size_t k) {
149
- if (n < maxn) {
150
- return cvector::roots.get (n + k);
151
- } else if (k % 2 == 0 ) {
152
- return root (n / 2 , k / 2 );
153
- } else {
154
- return std::polar (1 ., std::numbers::pi * k / n);
155
- }
156
- }
157
160
158
161
template <typename base>
159
162
void mul_slow (std::vector<base> &a, const std::vector<base> &b) {
@@ -171,7 +174,7 @@ namespace cp_algo::math::fft {
171
174
}
172
175
}
173
176
}
174
-
177
+
175
178
template <typename base>
176
179
struct dft {
177
180
cvector A;
@@ -184,7 +187,7 @@ namespace cp_algo::math::fft {
184
187
A.fft ();
185
188
}
186
189
}
187
-
190
+
188
191
std::vector<base> operator *= (dft const & B) {
189
192
assert (A.size () == B.A .size ());
190
193
size_t n = A.size ();
@@ -203,26 +206,17 @@ namespace cp_algo::math::fft {
203
206
auto operator * (dft const & B) const {
204
207
return dft (*this ) *= B;
205
208
}
206
-
209
+
207
210
point operator [](int i) const {return A.get (i);}
208
211
};
209
212
210
213
template <modint_type base>
211
214
struct dft <base> {
212
215
static constexpr int split = 1 << 15 ;
213
216
cvector A, B;
214
-
215
- void exec_on_roots (size_t n, size_t m, auto &&callback) {
216
- point cur = 1 ;
217
- point step = root (n, 1 );
218
- for (size_t i = 0 ; i < m; i++) {
219
- callback (i, cur);
220
- cur = (i & 15 ) == 0 || 2 * n < maxn ? root (n, i + 1 ) : cur * step;
221
- }
222
- }
223
217
224
218
dft (std::vector<base> const & a, size_t n): A(n), B(n) {
225
- exec_on_roots (2 * n, size (a), [&](size_t i, point rt) {
219
+ cvector:: exec_on_roots (2 * n, size (a), [&](size_t i, point rt) {
226
220
A.set (i % n, A.get (i % n) + ftype (a[i].rem () % split) * rt);
227
221
B.set (i % n, B.get (i % n) + ftype (a[i].rem () / split) * rt);
228
222
@@ -249,7 +243,7 @@ namespace cp_algo::math::fft {
249
243
B.ifft ();
250
244
C.ifft ();
251
245
std::vector<base> res (2 * n);
252
- exec_on_roots (2 * n, n, [&](size_t i, point rt) {
246
+ cvector:: exec_on_roots (2 * n, n, [&](size_t i, point rt) {
253
247
rt = conj (rt);
254
248
auto Ai = A.get (i) * rt;
255
249
auto Bi = B.get (i) * rt;
0 commit comments