@@ -70,6 +70,10 @@ namespace cp_algo::math::fft {
70
70
return {x[k / flen], y[k / flen]};
71
71
}
72
72
}
73
+ vpoint vget (size_t k) const {
74
+ return get<vpoint>(k);
75
+ }
76
+
73
77
size_t size () const {
74
78
return flen * std::size (x);
75
79
}
@@ -127,8 +131,8 @@ namespace cp_algo::math::fft {
127
131
}
128
132
};
129
133
const cvector cvector::roots = []() {
130
- cvector res (maxn);
131
- for (size_t n = 1 ; n < maxn ; n *= 2 ) {
134
+ cvector res (2 * maxn);
135
+ for (size_t n = 1 ; n < res. size () ; n *= 2 ) {
132
136
auto base = std::polar (1 ., std::numbers::pi / n);
133
137
point cur = 1 ;
134
138
for (size_t k = 0 ; k < n; k++) {
@@ -197,50 +201,52 @@ namespace cp_algo::math::fft {
197
201
template <modint_type base>
198
202
struct dft <base> {
199
203
static constexpr int split = 1 << 15 ;
200
- cvector A;
204
+ cvector A, B ;
201
205
202
- dft (std::vector<base> const & a, size_t n): A(n) {
203
- for (size_t i = 0 ; i < std::min (n, a.size ()); i++) {
204
- A.set (i, point{
205
- ftype (a[i].rem () % split),
206
- ftype (a[i].rem () / split)
207
- });
206
+ dft (std::vector<base> const & a, size_t n): A(n), B(n) {
207
+ for (size_t i = 0 ; i < size (a); i++) {
208
+ A.set (i % n, A.get (i % n) + ftype (a[i].rem () % split) * cvector::roots.get (2 * n + i));
209
+ B.set (i % n, B.get (i % n) + ftype (a[i].rem () / split) * cvector::roots.get (2 * n + i));
208
210
}
209
211
if (n) {
210
212
A.fft ();
213
+ B.fft ();
211
214
}
212
215
}
213
216
214
- std::vector<base> mul (auto &&B ) {
215
- assert (A.size () == B .size ());
217
+ std::vector<base> mul (auto &&C, auto &&D ) {
218
+ assert (A.size () == C .size ());
216
219
size_t n = A.size ();
217
220
if (!n) {
218
221
return std::vector<base>();
219
222
}
220
- for (size_t i = 0 ; 2 * i <= n; i++) {
221
- size_t j = (n - i) % n;
222
- size_t x = bitreverse (n, i);
223
- size_t y = bitreverse (n, j);
224
- auto Ax = A.get (x), Bx = B.get (x);
225
- auto Ay = A.get (y), By = B.get (y);
226
- B.set (x, Ax * (Bx + conj (By)));
227
- A.set (x, Ax * (Bx - conj (By)));
228
- B.set (y, Ay * (By + conj (Bx)));
229
- A.set (y, Ay * (By - conj (Bx)));
223
+ for (size_t i = 0 ; i < n; i += flen) {
224
+ auto tmp = A.vget (i) * D.vget (i) + B.vget (i) * C.vget (i);
225
+ A.set (i, A.vget (i) * C.vget (i));
226
+ B.set (i, B.vget (i) * D.vget (i));
227
+ C.set (i, tmp);
230
228
}
231
229
A.ifft ();
232
230
B.ifft ();
233
- std::vector<base> res (n);
231
+ C.ifft ();
232
+ std::vector<base> res (2 * n);
234
233
for (size_t i = 0 ; i < n; i++) {
235
- base A0 = llround (B.get (i).real ()) / 2 ;
236
- base A1 = llround (B.get (i).imag () + A.get (i).imag ()) / 2 ;
237
- base A2 = llround (A.get (i).real ()) / 2 ;
238
- res[i] = A0 + A1 * split - A2 * split * split;
234
+ auto Ai = A.get (i) * conj (cvector::roots.get (2 * n + i));
235
+ auto Bi = B.get (i) * conj (cvector::roots.get (2 * n + i));
236
+ auto Ci = C.get (i) * conj (cvector::roots.get (2 * n + i));
237
+ base A0 = llround (real (Ai));
238
+ base A1 = llround (real (Ci));
239
+ base A2 = llround (real (Bi));
240
+ res[i] = A0 + A1 * split + A2 * split * split;
241
+ base B0 = llround (imag (Ai));
242
+ base B1 = llround (imag (Ci));
243
+ base B2 = llround (imag (Bi));
244
+ res[n + i] = B0 + B1 * split + B2 * split * split;
239
245
}
240
246
return res;
241
247
}
242
248
std::vector<base> operator *= (auto &&B) {
243
- return mul (B.A );
249
+ return mul (B.A , B. B );
244
250
}
245
251
246
252
auto operator * (dft const & B) const {
@@ -254,12 +260,12 @@ namespace cp_algo::math::fft {
254
260
if (!as || !bs) {
255
261
return 0 ;
256
262
}
257
- return std::bit_ceil (as + bs - 1 );
263
+ return std::max (flen, std:: bit_ceil (as + bs - 1 ) / 2 );
258
264
}
259
265
260
266
template <typename base>
261
267
void mul (std::vector<base> &a, std::vector<base> const & b) {
262
- if (std::min (a.size (), b.size ()) < magic ) {
268
+ if (std::min (a.size (), b.size ()) < 1 ) {
263
269
mul_slow (a, b);
264
270
return ;
265
271
}
@@ -273,7 +279,7 @@ namespace cp_algo::math::fft {
273
279
}
274
280
template <typename base>
275
281
void circular_mul (std::vector<base> &a, std::vector<base> const & b) {
276
- auto n = std::bit_ceil (a.size ());
282
+ auto n = std::max (flen, std:: bit_ceil (max ( a.size (), b. size ())) / 2 );
277
283
auto A = dft<base>(a, n);
278
284
if (a == b) {
279
285
a = A *= dft<base>(A);
0 commit comments