@@ -263,6 +263,21 @@ namespace cp_algo::math::fft {
263
263
point operator [](int i) const {return A.get (i);}
264
264
};
265
265
266
+ void mul_slow (auto &a, auto const & b, size_t k) {
267
+ if (empty (a) || empty (b)) {
268
+ a.clear ();
269
+ } else {
270
+ int n = std::min (k, size (a));
271
+ int m = std::min (k, size (b));
272
+ a.resize (k);
273
+ for (int j = k - 1 ; j >= 0 ; j--) {
274
+ a[j] *= b[0 ];
275
+ for (int i = std::max (j - n, 0 ) + 1 ; i < std::min (j + 1 , m); i++) {
276
+ a[j] += a[j - i] * b[i];
277
+ }
278
+ }
279
+ }
280
+ }
266
281
size_t com_size (size_t as, size_t bs) {
267
282
if (!as || !bs) {
268
283
return 0 ;
@@ -271,18 +286,16 @@ namespace cp_algo::math::fft {
271
286
}
272
287
void mul_truncate (auto &a, auto const & b, size_t k) {
273
288
using base = std::decay_t <decltype (a[0 ])>;
274
- if (size (b) == 0 ) {
275
- a. clear ( );
289
+ if (std::min ({k, size (a), size (b)}) < 64 ) {
290
+ mul_slow (a, b, k );
276
291
return ;
277
292
}
278
293
auto n = std::max (flen, std::bit_ceil (
279
294
std::min (k, size (a)) + std::min (k, size (b)) - 1
280
295
) / 2 );
281
- auto A = dft<base>(std::views::take (a, k), n);
282
- if (size (a) != k) {
283
- a.resize (k);
284
- }
285
- if (a == b) {
296
+ a.resize (k);
297
+ auto A = dft<base>(a, n);
298
+ if (&a == &b) {
286
299
A.mul (dft<base>(A), a, k);
287
300
} else {
288
301
A.mul (dft<base>(std::views::take (b, k), n), a, k);
0 commit comments