@@ -212,7 +212,7 @@ namespace cp_algo::math::fft {
212
212
}
213
213
}
214
214
215
- void mul (auto &&C, auto && D, auto &res, size_t k) {
215
+ void mul (auto &&C, auto const & D, auto &res, size_t k) {
216
216
assert (A.size () == C.size ());
217
217
size_t n = A.size ();
218
218
if (!n) {
@@ -247,17 +247,24 @@ namespace cp_algo::math::fft {
247
247
res[n + i] = B0 + B1 * split + B2 * splitsplit;
248
248
});
249
249
}
250
- void mul (auto & &B, auto & res, size_t k) {
250
+ void mul_inplace (auto &B, auto & res, size_t k) {
251
251
mul (B.A , B.B , res, k);
252
252
}
253
- std::vector<base> operator *= (auto &&B) {
253
+ void mul (auto const & B, auto & res, size_t k) {
254
+ mul (cvector (B.A ), B.B , res, k);
255
+ }
256
+ std::vector<base> operator *= (dft &B) {
254
257
std::vector<base> res (2 * A.size ());
255
- mul (B.A , B.B , res, size (res));
258
+ mul_inplace (B, res, size (res));
259
+ return res;
260
+ }
261
+ std::vector<base> operator *= (dft const & B) {
262
+ std::vector<base> res (2 * A.size ());
263
+ mul (B, res, size (res));
256
264
return res;
257
265
}
258
-
259
266
auto operator * (dft const & B) const {
260
- return dft (*this ) *= dft (B) ;
267
+ return dft (*this ) *= B ;
261
268
}
262
269
263
270
point operator [](int i) const {return A.get (i);}
@@ -296,9 +303,9 @@ namespace cp_algo::math::fft {
296
303
a.resize (k);
297
304
auto A = dft<base>(a, n);
298
305
if (&a == &b) {
299
- A.mul (dft<base>(A) , a, k);
306
+ A.mul (A , a, k);
300
307
} else {
301
- A.mul (dft<base>(std::views::take (b, k), n), a, k);
308
+ A.mul_inplace (dft<base>(std::views::take (b, k), n), a, k);
302
309
}
303
310
}
304
311
void mul (auto &a, auto const & b) {
0 commit comments