@@ -46,7 +46,7 @@ namespace cp_algo::math::fft {
46
46
resize (n);
47
47
}
48
48
void resize (size_t n) {
49
- n = std::bit_ceil (std::max<size_t >(n, flen ));
49
+ n = std::bit_ceil (std::max<size_t >(n, 4 ));
50
50
if (size () != n) {
51
51
x.resize (n / flen);
52
52
y.resize (n / flen);
@@ -201,45 +201,50 @@ namespace cp_algo::math::fft {
201
201
202
202
dft (std::vector<base> const & a, size_t n): A(n) {
203
203
for (size_t i = 0 ; i < std::min (n, a.size ()); i++) {
204
- A.set (i, point{a[i].rem () % split, a[i].rem () / split});
204
+ A.set (i, point{
205
+ ftype (a[i].rem () % split),
206
+ ftype (a[i].rem () / split)
207
+ });
205
208
}
206
209
if (n) {
207
210
A.fft ();
208
211
}
209
212
}
210
213
211
- std::vector<base> operator *= (dft const & B) {
212
- assert (A.size () == B.A . size ());
214
+ std::vector<base> mul ( auto && B) {
215
+ assert (A.size () == B.size ());
213
216
size_t n = A.size ();
214
217
if (!n) {
215
218
return std::vector<base>();
216
219
}
217
- cvector C (n);
218
220
for (size_t i = 0 ; 2 * i <= n; i++) {
219
221
size_t j = (n - i) % n;
220
222
size_t x = bitreverse (n, i);
221
223
size_t y = bitreverse (n, j);
222
- auto Ax = A.get (x), Bx = B[x] ;
223
- auto Ay = A.get (y), By = B[y] ;
224
- C .set (x, Ax * (Bx + conj (By)));
224
+ auto Ax = A.get (x), Bx = B. get (x) ;
225
+ auto Ay = A.get (y), By = B. get (y) ;
226
+ B .set (x, Ax * (Bx + conj (By)));
225
227
A.set (x, Ax * (Bx - conj (By)));
226
- C .set (y, Ay * (By + conj (Bx)));
228
+ B .set (y, Ay * (By + conj (Bx)));
227
229
A.set (y, Ay * (By - conj (Bx)));
228
230
}
229
231
A.ifft ();
230
- C .ifft ();
232
+ B .ifft ();
231
233
std::vector<base> res (n);
232
234
for (size_t i = 0 ; i < n; i++) {
233
- base A0 = llround (C .get (i).real ()) / 2 ;
234
- base A1 = llround (C .get (i).imag () + A.get (i).imag ()) / 2 ;
235
+ base A0 = llround (B .get (i).real ()) / 2 ;
236
+ base A1 = llround (B .get (i).imag () + A.get (i).imag ()) / 2 ;
235
237
base A2 = llround (A.get (i).real ()) / 2 ;
236
238
res[i] = A0 + A1 * split - A2 * split * split;
237
239
}
238
240
return res;
239
241
}
242
+ std::vector<base> operator *= (auto &&B) {
243
+ return mul (B.A );
244
+ }
240
245
241
246
auto operator * (dft const & B) const {
242
- return dft (*this ) *= B ;
247
+ return dft (*this ) *= dft (B) ;
243
248
}
244
249
245
250
point operator [](int i) const {return A.get (i);}
@@ -261,7 +266,7 @@ namespace cp_algo::math::fft {
261
266
auto n = com_size (a.size (), b.size ());
262
267
auto A = dft<base>(a, n);
263
268
if (a == b) {
264
- a = A *= A ;
269
+ a = A *= dft<base>(A) ;
265
270
} else {
266
271
a = A *= dft<base>(b, n);
267
272
}
@@ -271,7 +276,7 @@ namespace cp_algo::math::fft {
271
276
auto n = std::bit_ceil (a.size ());
272
277
auto A = dft<base>(a, n);
273
278
if (a == b) {
274
- a = A *= A ;
279
+ a = A *= dft<base>(A) ;
275
280
} else {
276
281
a = A *= dft<base>(b, n);
277
282
}
0 commit comments