Skip to content

Commit 434b9be

Browse files
committed
Update
1 parent a4ef801 commit 434b9be

File tree

4 files changed

+39
-62
lines changed

4 files changed

+39
-62
lines changed

cp-algo/math/fft.hpp

Lines changed: 27 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <algorithm>
66
#include <complex>
77
#include <cassert>
8+
#include <ranges>
89
#include <vector>
910
#include <bit>
1011

@@ -33,16 +34,10 @@ namespace cp_algo::math::fft {
3334
struct cvector {
3435
static constexpr size_t pre_roots = 1 << 17;
3536
std::vector<vftype> x, y;
36-
cvector() {}
3737
cvector(size_t n) {
38-
resize(n);
39-
}
40-
void resize(size_t n) {
41-
n = std::bit_ceil(std::max<size_t>(n, flen));
42-
if(size() != n) {
43-
x.resize(n / flen);
44-
y.resize(n / flen);
45-
}
38+
n = std::max(flen, std::bit_ceil(n));
39+
x.resize(n / flen);
40+
y.resize(n / flen);
4641
}
4742
template<class pt = point>
4843
void set(size_t k, pt t) {
@@ -162,23 +157,6 @@ namespace cp_algo::math::fft {
162157
return res;
163158
}();
164159

165-
template<typename base>
166-
void mul_slow(std::vector<base> &a, const std::vector<base> &b) {
167-
if(a.empty() || b.empty()) {
168-
a.clear();
169-
} else {
170-
int n = a.size();
171-
int m = b.size();
172-
a.resize(n + m - 1);
173-
for(int k = n + m - 2; k >= 0; k--) {
174-
a[k] *= b[0];
175-
for(int j = std::max(k - n + 1, 1); j < std::min(k + 1, m); j++) {
176-
a[k] += a[k - j] * b[j];
177-
}
178-
}
179-
}
180-
}
181-
182160
template<typename base>
183161
struct dft {
184162
cvector A;
@@ -219,7 +197,7 @@ namespace cp_algo::math::fft {
219197
int split;
220198
cvector A, B;
221199

222-
dft(std::vector<base> const& a, size_t n): A(n), B(n) {
200+
dft(auto const& a, size_t n): A(n), B(n) {
223201
split = std::sqrt(base::mod());
224202
cvector::exec_on_roots(2 * n, size(a), [&](size_t i, point rt) {
225203
size_t ti = std::min(i, i - n);
@@ -233,7 +211,7 @@ namespace cp_algo::math::fft {
233211
}
234212
}
235213

236-
void mul(auto &&C, auto &&D, auto &res) {
214+
void mul(auto &&C, auto &&D, auto &res, size_t k) {
237215
assert(A.size() == C.size());
238216
size_t n = A.size();
239217
if(!n) {
@@ -249,9 +227,8 @@ namespace cp_algo::math::fft {
249227
A.ifft();
250228
B.ifft();
251229
C.ifft();
252-
res.resize(2 * n);
253230
auto splitsplit = (base(split) * split).rem();
254-
cvector::exec_on_roots(2 * n, n, [&](size_t i, point rt) {
231+
cvector::exec_on_roots(2 * n, std::min(n, k), [&](size_t i, point rt) {
255232
rt = conj(rt);
256233
auto Ai = A.get(i) * rt;
257234
auto Bi = B.get(i) * rt;
@@ -260,18 +237,21 @@ namespace cp_algo::math::fft {
260237
int64_t A1 = llround(real(Ci));
261238
int64_t A2 = llround(real(Bi));
262239
res[i] = A0 + A1 * split + A2 * splitsplit;
240+
if(n + i >= k) {
241+
return;
242+
}
263243
int64_t B0 = llround(imag(Ai));
264244
int64_t B1 = llround(imag(Ci));
265245
int64_t B2 = llround(imag(Bi));
266246
res[n + i] = B0 + B1 * split + B2 * splitsplit;
267247
});
268248
}
269-
void mul(auto &&B, auto& res) {
270-
mul(B.A, B.B, res);
249+
void mul(auto &&B, auto& res, size_t k) {
250+
mul(B.A, B.B, res, k);
271251
}
272252
std::vector<base> operator *= (auto &&B) {
273-
std::vector<base> res;
274-
mul(B.A, B.B, res);
253+
std::vector<base> res(2 * A.size());
254+
mul(B.A, B.B, res, size(res));
275255
return res;
276256
}
277257

@@ -288,30 +268,24 @@ namespace cp_algo::math::fft {
288268
}
289269
return std::max(flen, std::bit_ceil(as + bs - 1) / 2);
290270
}
291-
292-
template<typename base>
293-
void mul(std::vector<base> &a, std::vector<base> const& b) {
294-
if(std::min(a.size(), b.size()) < magic) {
295-
mul_slow(a, b);
296-
return;
271+
void mul_truncate(auto &a, auto const& b, size_t k) {
272+
using base = std::decay_t<decltype(a[0])>;
273+
auto n = std::max(flen, std::bit_ceil(k) / 2);
274+
auto A = dft<base>(std::views::take(a, k), n);
275+
if(size(a) < k) {
276+
a.resize(k);
297277
}
298-
auto n = com_size(a.size(), b.size());
299-
auto A = dft<base>(a, n);
300278
if(a == b) {
301-
A.mul(dft<base>(A), a);
279+
A.mul(dft<base>(A), a, k);
302280
} else {
303-
A.mul(dft<base>(b, n), a);
281+
A.mul(dft<base>(std::views::take(b, k), n), a, k);
304282
}
305283
}
306-
template<typename base>
307-
void circular_mul(std::vector<base> &a, std::vector<base> const& b) {
308-
auto n = std::max(flen, std::bit_ceil(max(a.size(), b.size())) / 2);
309-
auto A = dft<base>(a, n);
310-
if(a == b) {
311-
A.mul(dft<base>(A), a);
312-
} else {
313-
A.mul(dft<base>(b, n), a);
314-
}
284+
void mul(auto &a, auto const& b) {
285+
mul_truncate(a, b, std::max(size_t(0), size(a) + size(b) - 1));
286+
}
287+
void circular_mul(auto &a, auto const& b) {
288+
mul_truncate(a, b, std::max(size(a), size(b)));
315289
}
316290
}
317291
#endif // CP_ALGO_MATH_FFT_HPP

cp-algo/math/poly.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,20 @@
1212
#include <optional>
1313
#include <utility>
1414
#include <vector>
15+
#include <deque>
1516
#include <list>
1617
namespace cp_algo::math {
1718
template<typename T>
1819
struct poly_t {
1920
using base = T;
20-
std::vector<T> a;
21+
std::deque<T> a;
2122

2223
void normalize() {poly::impl::normalize(*this);}
2324

2425
poly_t(){}
2526
poly_t(T a0): a{a0} {normalize();}
26-
poly_t(std::vector<T> const& t): a(t) {normalize();}
27+
poly_t(std::vector<T> const& t): a(begin(t), end(t)) {normalize();}
28+
poly_t(std::deque<T> const& t): a(t) {normalize();}
2729

2830
poly_t operator -() const {return poly::impl::neg(*this);}
2931
poly_t& operator += (poly_t const& t) {return poly::impl::add(*this, t);}
@@ -563,10 +565,10 @@ namespace cp_algo::math {
563565

564566
int N = fft::com_size((n + 1) / 2, (n + 1) / 2);
565567

566-
auto Q0f = fft::dft(Q0.a, N);
567-
auto Q1f = fft::dft(Q1.a, N);
568-
auto P0f = fft::dft(P0.a, N);
569-
auto P1f = fft::dft(P1.a, N);
568+
auto Q0f = fft::dft<T>(Q0.a, N);
569+
auto Q1f = fft::dft<T>(Q1.a, N);
570+
auto P0f = fft::dft<T>(P0.a, N);
571+
auto P1f = fft::dft<T>(P1.a, N);
570572

571573
if(k % 2) {
572574
P = poly_t(Q0f * P1f) + poly_t(Q1f * P0f);

cp-algo/math/poly/impl/div.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,11 @@ namespace cp_algo::math::poly::impl {
100100

101101
int N = fft::com_size((n + 1) / 2, (n + 1) / 2);
102102

103-
auto q0f = fft::dft(q0.a, N);
104-
auto q1f = fft::dft(q1.a, N);
103+
auto q0f = fft::dft<typename poly::base>(q0.a, N);
104+
auto q1f = fft::dft<typename poly::base>(q1.a, N);
105105

106106
// Q(x)*Q(-x) = Q0(x^2)^2 - x^2 Q1(x^2)^2
107-
auto qqf = fft::dft(inv(
107+
auto qqf = fft::dft<typename poly::base>(inv(
108108
poly(q0f * q0f) - poly(q1f * q1f).mul_xk(1)
109109
, (n + 1) / 2).a, N);
110110

verify/poly/inv.test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// @brief Inv of Power Series
22
#define PROBLEM "https://judge.yosupo.jp/problem/inv_of_formal_power_series"
3+
#pragma GCC optimize("Ofast,unroll-loops")
34
#include "cp-algo/math/poly.hpp"
45
#include <bits/stdc++.h>
56

0 commit comments

Comments
 (0)