5
5
#include < algorithm>
6
6
#include < complex>
7
7
#include < cassert>
8
+ #include < ranges>
8
9
#include < vector>
9
10
#include < bit>
10
11
@@ -33,16 +34,10 @@ namespace cp_algo::math::fft {
33
34
struct cvector {
34
35
static constexpr size_t pre_roots = 1 << 17 ;
35
36
std::vector<vftype> x, y;
36
- cvector () {}
37
37
// Sizing constructor. In the post-change form ('+' lines) it rounds n up to a
// power of two with std::bit_ceil, clamps it to at least flen, and resizes both
// x and y to n / flen elements.
// The '-' lines are the removed resize() helper the constructor used to call;
// it clamped to flen before rounding and skipped the resize when size() already matched.
// NOTE(review): old and new orderings agree only if flen is itself a power of two - confirm.
cvector (size_t n) {
38
- resize (n);
39
- }
40
- void resize (size_t n) {
41
- n = std::bit_ceil (std::max<size_t >(n, flen));
42
- if (size () != n) {
43
- x.resize (n / flen);
44
- y.resize (n / flen);
45
- }
38
+ n = std::max (flen, std::bit_ceil (n));
39
+ x.resize (n / flen);
40
+ y.resize (n / flen);
46
41
}
47
42
template <class pt = point>
48
43
void set (size_t k, pt t) {
@@ -162,23 +157,6 @@ namespace cp_algo::math::fft {
162
157
return res;
163
158
}();
164
159
165
// Removed by this change: quadratic in-place product a *= b.
// If either input is empty, a is cleared. Otherwise a is grown to n + m - 1
// entries and filled from the highest index down, so every a[k - j] read in the
// inner loop (j >= 1) still holds an original, not-yet-overwritten value of a.
- template <typename base>
166
- void mul_slow (std::vector<base> &a, const std::vector<base> &b) {
167
- if (a.empty () || b.empty ()) {
168
- a.clear ();
169
- } else {
170
- int n = a.size ();
171
- int m = b.size ();
172
- a.resize (n + m - 1 );
173
// k walks output indices downward; the j range keeps k - j inside [0, n) and j inside [1, m).
- for (int k = n + m - 2 ; k >= 0 ; k--) {
174
- a[k] *= b[0 ];
175
- for (int j = std::max (k - n + 1 , 1 ); j < std::min (k + 1 , m); j++) {
176
- a[k] += a[k - j] * b[j];
177
- }
178
- }
179
- }
180
- }
181
-
182
160
template <typename base>
183
161
struct dft {
184
162
cvector A;
@@ -219,7 +197,7 @@ namespace cp_algo::math::fft {
219
197
int split;
220
198
cvector A, B;
221
199
222
- dft (std::vector<base> const & a, size_t n): A(n), B(n) {
200
+ dft (auto const & a, size_t n): A(n), B(n) {
223
201
split = std::sqrt (base::mod ());
224
202
cvector::exec_on_roots (2 * n, size (a), [&](size_t i, point rt) {
225
203
size_t ti = std::min (i, i - n);
@@ -233,7 +211,7 @@ namespace cp_algo::math::fft {
233
211
}
234
212
}
235
213
236
- void mul (auto &&C, auto &&D, auto &res) {
214
+ void mul (auto &&C, auto &&D, auto &res, size_t k ) {
237
215
assert (A.size () == C.size ());
238
216
size_t n = A.size ();
239
217
if (!n) {
@@ -249,9 +227,8 @@ namespace cp_algo::math::fft {
249
227
A.ifft ();
250
228
B.ifft ();
251
229
C.ifft ();
252
- res.resize (2 * n);
253
230
auto splitsplit = (base (split) * split).rem ();
254
- cvector::exec_on_roots (2 * n, n , [&](size_t i, point rt) {
231
+ cvector::exec_on_roots (2 * n, std::min (n, k) , [&](size_t i, point rt) {
255
232
rt = conj (rt);
256
233
auto Ai = A.get (i) * rt;
257
234
auto Bi = B.get (i) * rt;
@@ -260,18 +237,21 @@ namespace cp_algo::math::fft {
260
237
int64_t A1 = llround (real (Ci));
261
238
int64_t A2 = llround (real (Bi));
262
239
res[i] = A0 + A1 * split + A2 * splitsplit;
240
+ if (n + i >= k) {
241
+ return ;
242
+ }
263
243
int64_t B0 = llround (imag (Ai));
264
244
int64_t B1 = llround (imag (Ci));
265
245
int64_t B2 = llround (imag (Bi));
266
246
res[n + i] = B0 + B1 * split + B2 * splitsplit;
267
247
});
268
248
}
269
// Convenience overload: forwards the two members B.A and B.B of the other
// operand to the (C, D, res, k) overload. The post-change signature ('+' lines)
// adds the parameter k and passes it through unchanged.
- void mul (auto &&B, auto & res) {
270
- mul (B.A , B.B , res);
249
+ void mul (auto &&B, auto & res, size_t k ) {
250
+ mul (B.A , B.B , res, k );
271
251
}
272
252
// Multiplies by B and returns the result as a new vector.
// In the post-change form ('+' lines) the result is pre-sized here to
// 2 * A.size() entries and that size is passed to mul as k; the removed lines
// ('-') passed an empty vector and left the sizing to mul.
std::vector<base> operator *= (auto &&B) {
273
- std::vector<base> res;
274
- mul (B.A , B.B , res);
253
+ std::vector<base> res ( 2 * A. size ()) ;
254
+ mul (B.A , B.B , res, size (res) );
275
255
return res;
276
256
}
277
257
@@ -288,30 +268,24 @@ namespace cp_algo::math::fft {
288
268
}
289
269
return std::max (flen, std::bit_ceil (as + bs - 1 ) / 2 );
290
270
}
291
-
292
- template < typename base>
293
- void mul ( std::vector<base> &a , std::vector<base> const & b) {
294
- if (std::min (a. size (), b. size ()) < magic) {
295
- mul_slow (a, b);
296
- return ;
271
// Multiplies a by b in place, passing k to dft::mul as the output length.
// Only the first k elements of a and of b are read (std::views::take).
// The transform size n depends on k alone: max(flen, bit_ceil(k) / 2).
// a is grown to k entries when shorter; it is never shrunk when longer.
// The '-' lines are the body of the old mul(), which chose n via com_size and
// fell back to mul_slow for small inputs.
// NOTE(review): because n ignores the input lengths, the result presumably wraps
// around cyclically when the full product is longer than the transform - confirm against dft::mul.
+ void mul_truncate ( auto &a, auto const & b, size_t k) {
272
+ using base = std:: decay_t < decltype (a[ 0 ])>;
273
+ auto n = std::max (flen , std::bit_ceil (k) / 2 );
274
// A is built before a is resized, so it sees the original contents of a.
+ auto A = dft<base> (std::views::take (a, k), n);
275
+ if ( size (a) < k) {
276
+ a. resize (k) ;
297
277
}
298
- auto n = com_size (a.size (), b.size ());
299
- auto A = dft<base>(a, n);
300
278
// Squaring shortcut: when a (after the resize above) compares equal to b,
// a copy of A is used as the second operand instead of transforming b.
if (a == b) {
301
- A.mul (dft<base>(A), a);
279
+ A.mul (dft<base>(A), a, k );
302
280
} else {
303
- A.mul (dft<base>(b, n), a);
281
+ A.mul (dft<base>(std::views::take ( b, k), n), a, k );
304
282
}
305
283
}
306
- template <typename base>
307
- void circular_mul (std::vector<base> &a, std::vector<base> const & b) {
308
- auto n = std::max (flen, std::bit_ceil (max (a.size (), b.size ())) / 2 );
309
- auto A = dft<base>(a, n);
310
- if (a == b) {
311
- A.mul (dft<base>(A), a);
312
- } else {
313
- A.mul (dft<base>(b, n), a);
314
- }
284
// Multiplies a by b in place; the result has size(a) + size(b) - 1 entries.
// If either input is empty the product is empty, so a is cleared and nothing
// else is done. This case must be handled before computing the length:
// size(a) + size(b) - 1 is unsigned, so with two empty inputs it wraps to the
// maximum size_t value (a std::max against 0 cannot catch that), and with only
// b empty it yields size(a) - 1, which would leave the last entry of a untouched.
void mul(auto &a, auto const& b) {
    if (size(a) == 0 || size(b) == 0) {
        a.clear();
        return;
    }
    // Both sizes are at least 1 here, so the subtraction cannot underflow.
    mul_truncate(a, b, size(a) + size(b) - 1);
}
287
// In-place multiplication of a by b whose output length is the larger of the
// two input sizes; all the work is delegated to mul_truncate.
void circular_mul(auto &a, auto const& b) {
    const auto len_a = size(a);
    const auto len_b = size(b);
    mul_truncate(a, b, std::max(len_a, len_b));
}
316
290
}
317
291
#endif // CP_ALGO_MATH_FFT_HPP
0 commit comments