Skip to content

Commit 36f1160

Browse files
committed
Improve convolution
1 parent b4c2bb6 commit 36f1160

File tree

2 files changed

+30
-10
lines changed

2 files changed

+30
-10
lines changed

cp-algo/math/common.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace cp_algo::math {
66
#ifdef CP_ALGO_MAXN
77
const int maxn = CP_ALGO_MAXN;
88
#else
9-
const int maxn = 1 << 20;
9+
const int maxn = 1 << 19;
1010
#endif
1111
const int magic = 250; // threshold for sizes to run the naive algo
1212

cp-algo/math/fft.hpp

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,15 @@ namespace cp_algo::math::fft {
145145
}
146146
return res;
147147
}();
148+
// Looks up or computes the root of unity addressed by the pair (n, k).
//
//  * n < maxn: the value is read from the precomputed table as
//    cvector::roots.get(n + k). NOTE(review): the table's construction is not
//    visible here; presumably it stores the same quantity the last branch
//    computes -- confirm against the table initializer.
//  * k even:   both arguments are halved and the lookup is retried; the ratio
//    k / n is unchanged as long as n is even too (n is presumably a power of
//    two -- confirm against callers).
//  * k odd:    the value is computed directly as std::polar(1., pi * k / n).
//
// Declared `inline` because this definition lives in a header: a non-inline
// function here would be defined once per including translation unit and
// break the link as soon as the header is used from two of them.
inline point root(size_t n, size_t k) {
    if(n < maxn) {
        return cvector::roots.get(n + k);
    } else if(k % 2 == 0) {
        return root(n / 2, k / 2);
    } else {
        return std::polar(1., std::numbers::pi * k / n);
    }
}
148157

149158
template<typename base>
150159
void mul_slow(std::vector<base> &a, const std::vector<base> &b) {
@@ -202,12 +211,22 @@ namespace cp_algo::math::fft {
202211
struct dft<base> {
203212
static constexpr int split = 1 << 15;
204213
cvector A, B;
214+
215+
void exec_on_roots(size_t n, size_t m, auto &&callback) {
216+
point cur = 1;
217+
point step = root(n, 1);
218+
for(size_t i = 0; i < m; i++) {
219+
callback(i, cur);
220+
cur = (i & 15) == 0 || 2 * n < maxn ? root(n, i + 1) : cur * step;
221+
}
222+
}
205223

206224
dft(std::vector<base> const& a, size_t n): A(n), B(n) {
207-
for(size_t i = 0; i < size(a); i++) {
208-
A.set(i % n, A.get(i % n) + ftype(a[i].rem() % split) * cvector::roots.get(2 * n + i));
209-
B.set(i % n, B.get(i % n) + ftype(a[i].rem() / split) * cvector::roots.get(2 * n + i));
210-
}
225+
exec_on_roots(2 * n, size(a), [&](size_t i, point rt) {
226+
A.set(i % n, A.get(i % n) + ftype(a[i].rem() % split) * rt);
227+
B.set(i % n, B.get(i % n) + ftype(a[i].rem() / split) * rt);
228+
229+
});
211230
if(n) {
212231
A.fft();
213232
B.fft();
@@ -230,10 +249,11 @@ namespace cp_algo::math::fft {
230249
B.ifft();
231250
C.ifft();
232251
std::vector<base> res(2 * n);
233-
for(size_t i = 0; i < n; i++) {
234-
auto Ai = A.get(i) * conj(cvector::roots.get(2 * n + i));
235-
auto Bi = B.get(i) * conj(cvector::roots.get(2 * n + i));
236-
auto Ci = C.get(i) * conj(cvector::roots.get(2 * n + i));
252+
exec_on_roots(2 * n, n, [&](size_t i, point rt) {
253+
rt = conj(rt);
254+
auto Ai = A.get(i) * rt;
255+
auto Bi = B.get(i) * rt;
256+
auto Ci = C.get(i) * rt;
237257
base A0 = llround(real(Ai));
238258
base A1 = llround(real(Ci));
239259
base A2 = llround(real(Bi));
@@ -242,7 +262,7 @@ namespace cp_algo::math::fft {
242262
base B1 = llround(imag(Ci));
243263
base B2 = llround(imag(Bi));
244264
res[n + i] = B0 + B1 * split + B2 * split * split;
245-
}
265+
});
246266
return res;
247267
}
248268
std::vector<base> operator *= (auto &&B) {

0 commit comments

Comments
 (0)