Skip to content

Commit 5b1d62b

Browse files
committed
Improve dft
1 parent 32f6841 commit 5b1d62b

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

cp-algo/math/fft.hpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ namespace cp_algo::math::fft {
212212
}
213213
}
214214

215-
void mul(auto &&C, auto &&D, auto &res, size_t k) {
215+
void mul(auto &&C, auto const& D, auto &res, size_t k) {
216216
assert(A.size() == C.size());
217217
size_t n = A.size();
218218
if(!n) {
@@ -247,17 +247,24 @@ namespace cp_algo::math::fft {
247247
res[n + i] = B0 + B1 * split + B2 * splitsplit;
248248
});
249249
}
250-
void mul(auto &&B, auto& res, size_t k) {
250+
void mul_inplace(auto &B, auto& res, size_t k) {
251251
mul(B.A, B.B, res, k);
252252
}
253-
std::vector<base> operator *= (auto &&B) {
253+
void mul(auto const& B, auto& res, size_t k) {
254+
mul(cvector(B.A), B.B, res, k);
255+
}
256+
std::vector<base> operator *= (dft &B) {
254257
std::vector<base> res(2 * A.size());
255-
mul(B.A, B.B, res, size(res));
258+
mul_inplace(B, res, size(res));
259+
return res;
260+
}
261+
std::vector<base> operator *= (dft const& B) {
262+
std::vector<base> res(2 * A.size());
263+
mul(B, res, size(res));
256264
return res;
257265
}
258-
259266
auto operator * (dft const& B) const {
260-
return dft(*this) *= dft(B);
267+
return dft(*this) *= B;
261268
}
262269

263270
point operator [](int i) const {return A.get(i);}
@@ -296,9 +303,9 @@ namespace cp_algo::math::fft {
296303
a.resize(k);
297304
auto A = dft<base>(a, n);
298305
if(&a == &b) {
299-
A.mul(dft<base>(A), a, k);
306+
A.mul(A, a, k);
300307
} else {
301-
A.mul(dft<base>(std::views::take(b, k), n), a, k);
308+
A.mul_inplace(dft<base>(std::views::take(b, k), n), a, k);
302309
}
303310
}
304311
void mul(auto &a, auto const& b) {

cp-algo/math/poly/impl/div.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ namespace cp_algo::math::poly::impl {
8989
auto qqf = fft::dft<base>(qq.a, N);
9090
int M = q0.deg() + (n + 1) / 2;
9191
std::deque<base> A(M), B(M);
92-
q0f.mul(fft::dft<base>(qqf), A, M);
93-
q1f.mul(qqf, B, M);
92+
q0f.mul(qqf, A, M);
93+
q1f.mul_inplace(qqf, B, M);
9494
q.a.resize(n + 1);
9595
for(size_t i = 0; i < n; i += 2) {
9696
q.a[i] = A[q0.deg() + i / 2];
@@ -122,8 +122,8 @@ namespace cp_algo::math::poly::impl {
122122
auto qqf = fft::dft<base>(qq.a, N);
123123

124124
std::deque<base> A((n + 1) / 2), B((n + 1) / 2);
125-
q0f.mul(fft::dft<base>(qqf), A, (n + 1) / 2);
126-
q1f.mul(qqf, B, (n + 1) / 2);
125+
q0f.mul(qqf, A, (n + 1) / 2);
126+
q1f.mul_inplace(qqf, B, (n + 1) / 2);
127127
p.a.resize(n + 1);
128128
for(size_t i = 0; i < n; i += 2) {
129129
p.a[i] = A[i / 2];

0 commit comments

Comments
 (0)