8
8
#include < vector>
9
9
#include < bit>
10
10
namespace cp_algo ::math::fft {
11
- using ftype = double ;
12
- using point = std::complex<ftype>;
13
-
14
- std::vector<point> w; // w[2^n + k] = exp(pi * k / (2^n))
15
- std::vector<int > bitr;// b[2^n + k] = bitreverse(k)
16
- const ftype pi = acos(-1 );
17
- bool initiated = 0 ;
11
+ const auto bitr = [](){
12
+ std::vector<size_t > bitr (maxn);
13
+ for (size_t n = 2 ; n < maxn; n *= 2 ) {
14
+ for (size_t k = 0 ; k < n; k++) {
15
+ bitr[n + k] = bitr[n + k / 2 ] / 2 + (k & 1 ) * (n / 2 );
16
+ }
17
+ }
18
+ return bitr;
19
+ }();
18
20
size_t bitreverse (size_t n, size_t k) {
19
21
size_t hn = n / 2 ;
20
22
if (k >= hn && n > 1 ) {
@@ -23,62 +25,123 @@ namespace cp_algo::math::fft {
23
25
return 2 * bitr[hn + k];
24
26
}
25
27
}
26
- void init () {
27
- if (!initiated) {
28
- w.resize (maxn);
29
- bitr.resize (maxn);
30
- for (int i = 1 ; i < maxn; i *= 2 ) {
31
- int ti = i / 2 ;
32
- ftype arg = pi / i;
33
- point base = std::polar (1 ., arg);
34
- point cur = 1 .;
35
- for (int j = 0 ; j < i; j++) {
36
- if ((j & 15 ) == 0 ) {
37
- cur = std::polar (1 ., j * arg);
38
- }
39
- w[i + j] = cur;
40
- cur *= base;
41
- if (ti) {
42
- bitr[i + j] = 2 * bitr[ti + j % ti] + (j >= ti);
43
- }
44
- }
28
+
29
+ using ftype = double ;
30
+ static constexpr size_t bytes = 32 ;
31
+ static constexpr size_t flen = bytes / sizeof (ftype);
32
+ using point = std::complex<ftype>;
33
+ using vftype [[gnu::vector_size(bytes)]] = ftype;
34
+ using vpoint = std::complex<vftype>;
35
+
36
+ constexpr vftype to_vec (ftype x) {
37
+ return vftype{} + x;
38
+ }
39
+ constexpr vpoint to_vec (point r) {
40
+ return {to_vec (r.real ()), to_vec (r.imag ())};
41
+ }
42
+ struct cvector {
43
+ std::vector<vftype> x, y;
44
+ cvector () {}
45
+ cvector (size_t n) {
46
+ resize (n);
47
+ }
48
+ void resize (size_t n) {
49
+ n = std::bit_ceil (std::max<size_t >(n, 4 ));
50
+ if (size () != n) {
51
+ x.resize (n / flen);
52
+ y.resize (n / flen);
45
53
}
46
- initiated = 1 ;
47
54
}
48
- }
49
-
50
- void ifft (auto &a, int n) {
51
- init ();
52
- if (n == 1 ) {
53
- return ;
55
+ template <class pt = point>
56
+ void set (size_t k, pt t) {
57
+ if constexpr (std::is_same_v<pt, point>) {
58
+ x[k / flen][k % flen] = real (t);
59
+ y[k / flen][k % flen] = imag (t);
60
+ } else {
61
+ x[k / flen] = real (t);
62
+ y[k / flen] = imag (t);
63
+ }
64
+ }
65
+ template <class pt = point>
66
+ pt get (size_t k) const {
67
+ if constexpr (std::is_same_v<pt, point>) {
68
+ return {x[k / flen][k % flen], y[k / flen][k % flen]};
69
+ } else {
70
+ return {x[k / flen], y[k / flen]};
71
+ }
72
+ }
73
+ size_t size () const {
74
+ return flen * std::size (x);
54
75
}
55
- for (int i = 1 ; i < n; i *= 2 ) {
56
- for (int j = 0 ; j < n; j += 2 * i) {
57
- for (int k = j; k < j + i; k++) {
58
- std::tie (a[k], a[k + i]) = std::pair{
59
- a[k] + a[k + i] * conj (w[i + k - j]),
60
- a[k] - a[k + i] * conj (w[i + k - j])
76
+ void dot (cvector const & t) {
77
+ size_t n = size ();
78
+ for (size_t k = 0 ; k < n; k += flen) {
79
+ set (k, get<vpoint>(k) * t.get <vpoint>(k));
80
+ }
81
+ }
82
+ static const cvector roots;
83
+
84
+ void ifft () {
85
+ size_t n = size ();
86
+ for (size_t i = 1 ; i < n; i *= 2 ) {
87
+ for (size_t j = 0 ; j < n; j += 2 * i) {
88
+ auto butterfly = [&]<class pt >(pt) {
89
+ size_t step = sizeof (pt) / sizeof (point);
90
+ for (size_t k = j; k < j + i; k += step) {
91
+ auto T = get<pt>(k + i) * conj (roots.get <pt>(i + k - j));
92
+ set (k + i, get<pt>(k) - T);
93
+ set (k, get<pt>(k) + T);
94
+ }
61
95
};
96
+ if (2 * i <= flen) {
97
+ butterfly (point{});
98
+ } else {
99
+ butterfly (vpoint{});
100
+ }
62
101
}
63
102
}
103
+ for (size_t k = 0 ; k < n; k += flen) {
104
+ set (k, get<vpoint>(k) /= to_vec (n));
105
+ }
64
106
}
65
- }
66
- void fft ( auto &a, int n) {
67
- init ();
68
- if (n == 1 ) {
69
- return ;
70
- }
71
- for (int i = n / 2 ; i >= 1 ; i /= 2 ) {
72
- for ( int j = 0 ; j < n; j += 2 * i) {
73
- for ( int k = j; k < j + i; k++) {
74
- std::tie (a[k], a[k + i]) = std::pair{
75
- a[k] + a[ k + i],
76
- (a[k] - a[k + i]) * w[i + k - j]
107
+ void fft () {
108
+ size_t n = size ();
109
+ for ( size_t i = n / 2 ; i >= 1 ; i /= 2 ) {
110
+ for ( size_t j = 0 ; j < n; j += 2 * i ) {
111
+ auto butterfly = [&]< class pt >(pt) {
112
+ size_t step = sizeof (pt) / sizeof (point);
113
+ for (size_t k = j; k < j + i; k += step ) {
114
+ auto A = get<pt>(k) + get<pt>(k + i);
115
+ auto B = get<pt>(k) - get<pt>(k + i);
116
+ set (k, A);
117
+ set ( k + i, B * roots. get <pt>(i + k - j));
118
+ }
77
119
};
120
+ if (2 * i <= flen) {
121
+ butterfly (point{});
122
+ } else {
123
+ butterfly (vpoint{});
124
+ }
78
125
}
79
126
}
80
127
}
81
- }
128
+ };
129
+ const cvector cvector::roots = []() {
130
+ cvector res (maxn);
131
+ for (size_t n = 1 ; n < maxn; n *= 2 ) {
132
+ auto base = std::polar (1 ., std::numbers::pi / n);
133
+ point cur = 1 ;
134
+ for (size_t k = 0 ; k < n; k++) {
135
+ if ((k & 15 ) == 0 ) {
136
+ cur = std::polar (1 ., std::numbers::pi * k / n);
137
+ }
138
+ res.set (n + k, cur);
139
+ cur *= base;
140
+ }
141
+ }
142
+ return res;
143
+ }();
144
+
82
145
template <typename base>
83
146
void mul_slow (std::vector<base> &a, const std::vector<base> &b) {
84
147
if (a.empty () || b.empty ()) {
@@ -98,14 +161,14 @@ namespace cp_algo::math::fft {
98
161
99
162
template <typename base>
100
163
struct dft {
101
- std::vector<point> A;
164
+ cvector A;
102
165
103
166
dft (std::vector<base> const & a, size_t n): A(n) {
104
167
for (size_t i = 0 ; i < std::min (n, a.size ()); i++) {
105
- A[i] = a[i];
168
+ A. set (i, a[i]) ;
106
169
}
107
170
if (n) {
108
- fft (A, n );
171
+ A. fft ();
109
172
}
110
173
}
111
174
@@ -115,39 +178,33 @@ namespace cp_algo::math::fft {
115
178
if (!n) {
116
179
return std::vector<base>();
117
180
}
118
- for (size_t i = 0 ; i < n; i++) {
119
- A[i] *= B[i];
120
- }
121
- ifft (A, n);
122
- for (size_t i = 0 ; i < n; i++) {
123
- A[i] /= n;
124
- }
125
- if constexpr (std::is_same_v<base, point>) {
126
- return A;
127
- } else {
128
- return {begin (A), end (A)};
181
+ A.dot (B.A );
182
+ A.ifft ();
183
+ std::vector<base> res (n);
184
+ for (size_t k = 0 ; k < n; k++) {
185
+ res[k] = A.get (k);
129
186
}
187
+ return res;
130
188
}
131
189
132
190
auto operator * (dft const & B) const {
133
191
return dft (*this ) *= B;
134
192
}
135
193
136
- point& operator [](int i) {return A[i];}
137
- point operator [](int i) const {return A[i];}
194
+ point operator [](int i) const {return A.get (i);}
138
195
};
139
196
140
197
template <modint_type base>
141
198
struct dft <base> {
142
199
static constexpr int split = 1 << 15 ;
143
- std::vector<point> A;
200
+ cvector A;
144
201
145
202
dft (std::vector<base> const & a, size_t n): A(n) {
146
203
for (size_t i = 0 ; i < std::min (n, a.size ()); i++) {
147
- A[i] = point ( a[i].rem () % split, a[i].rem () / split);
204
+ A. set (i, point{ a[i].rem () % split, a[i].rem () / split} );
148
205
}
149
206
if (n) {
150
- fft (A, n );
207
+ A. fft ();
151
208
}
152
209
}
153
210
@@ -157,26 +214,25 @@ namespace cp_algo::math::fft {
157
214
if (!n) {
158
215
return std::vector<base>();
159
216
}
160
- std::vector<point> C (n);
217
+ cvector C (n);
161
218
for (size_t i = 0 ; 2 * i <= n; i++) {
162
219
size_t j = (n - i) % n;
163
220
size_t x = bitreverse (n, i);
164
221
size_t y = bitreverse (n, j);
165
- std::tie (C[x], A[x], C[y], A[y]) = std::make_tuple (
166
- A[x] * (B[x] + conj (B[y])),
167
- A[x] * (B[x] - conj (B[y])),
168
- A[y] * (B[y] + conj (B[x])),
169
- A[y] * (B[y] - conj (B[x]))
170
- );
171
- }
172
- ifft (C, n);
173
- ifft (A, n);
174
- int t = 2 * n;
222
+ auto Ax = A.get (x), Bx = B[x];
223
+ auto Ay = A.get (y), By = B[y];
224
+ C.set (x, Ax * (Bx + conj (By)));
225
+ A.set (x, Ax * (Bx - conj (By)));
226
+ C.set (y, Ay * (By + conj (Bx)));
227
+ A.set (y, Ay * (By - conj (Bx)));
228
+ }
229
+ A.ifft ();
230
+ C.ifft ();
175
231
std::vector<base> res (n);
176
232
for (size_t i = 0 ; i < n; i++) {
177
- base A0 = llround (C[i] .real () / t) ;
178
- base A1 = llround (C[i] .imag () / t + A[i] .imag () / t) ;
179
- base A2 = llround (A[i] .real () / t) ;
233
+ base A0 = llround (C. get (i) .real ()) / 2 ;
234
+ base A1 = llround (C. get (i) .imag () + A. get (i) .imag ()) / 2 ;
235
+ base A2 = llround (A. get (i) .real ()) / 2 ;
180
236
res[i] = A0 + A1 * split - A2 * split * split;
181
237
}
182
238
return res;
@@ -186,8 +242,7 @@ namespace cp_algo::math::fft {
186
242
return dft (*this ) *= B;
187
243
}
188
244
189
- point& operator [](int i) {return A[i];}
190
- point operator [](int i) const {return A[i];}
245
+ point operator [](int i) const {return A.get (i);}
191
246
};
192
247
193
248
size_t com_size (size_t as, size_t bs) {
0 commit comments