Skip to content

Commit b4c2bb6

Browse files
committed
Convolution imaginary-cyclic speedup
1 parent 08816d6 commit b4c2bb6

File tree

1 file changed

+36
-30
lines changed

1 file changed

+36
-30
lines changed

cp-algo/math/fft.hpp

Lines changed: 36 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,10 @@ namespace cp_algo::math::fft {
7070
return {x[k / flen], y[k / flen]};
7171
}
7272
}
73+
vpoint vget(size_t k) const {
74+
return get<vpoint>(k);
75+
}
76+
7377
size_t size() const {
7478
return flen * std::size(x);
7579
}
@@ -127,8 +131,8 @@ namespace cp_algo::math::fft {
127131
}
128132
};
129133
const cvector cvector::roots = []() {
130-
cvector res(maxn);
131-
for(size_t n = 1; n < maxn; n *= 2) {
134+
cvector res(2 * maxn);
135+
for(size_t n = 1; n < res.size(); n *= 2) {
132136
auto base = std::polar(1., std::numbers::pi / n);
133137
point cur = 1;
134138
for(size_t k = 0; k < n; k++) {
@@ -197,50 +201,52 @@ namespace cp_algo::math::fft {
197201
template<modint_type base>
198202
struct dft<base> {
199203
static constexpr int split = 1 << 15;
200-
cvector A;
204+
cvector A, B;
201205

202-
dft(std::vector<base> const& a, size_t n): A(n) {
203-
for(size_t i = 0; i < std::min(n, a.size()); i++) {
204-
A.set(i, point{
205-
ftype(a[i].rem() % split),
206-
ftype(a[i].rem() / split)
207-
});
206+
dft(std::vector<base> const& a, size_t n): A(n), B(n) {
207+
for(size_t i = 0; i < size(a); i++) {
208+
A.set(i % n, A.get(i % n) + ftype(a[i].rem() % split) * cvector::roots.get(2 * n + i));
209+
B.set(i % n, B.get(i % n) + ftype(a[i].rem() / split) * cvector::roots.get(2 * n + i));
208210
}
209211
if(n) {
210212
A.fft();
213+
B.fft();
211214
}
212215
}
213216

214-
std::vector<base> mul(auto &&B) {
215-
assert(A.size() == B.size());
217+
std::vector<base> mul(auto &&C, auto &&D) {
218+
assert(A.size() == C.size());
216219
size_t n = A.size();
217220
if(!n) {
218221
return std::vector<base>();
219222
}
220-
for(size_t i = 0; 2 * i <= n; i++) {
221-
size_t j = (n - i) % n;
222-
size_t x = bitreverse(n, i);
223-
size_t y = bitreverse(n, j);
224-
auto Ax = A.get(x), Bx = B.get(x);
225-
auto Ay = A.get(y), By = B.get(y);
226-
B.set(x, Ax * (Bx + conj(By)));
227-
A.set(x, Ax * (Bx - conj(By)));
228-
B.set(y, Ay * (By + conj(Bx)));
229-
A.set(y, Ay * (By - conj(Bx)));
223+
for(size_t i = 0; i < n; i += flen) {
224+
auto tmp = A.vget(i) * D.vget(i) + B.vget(i) * C.vget(i);
225+
A.set(i, A.vget(i) * C.vget(i));
226+
B.set(i, B.vget(i) * D.vget(i));
227+
C.set(i, tmp);
230228
}
231229
A.ifft();
232230
B.ifft();
233-
std::vector<base> res(n);
231+
C.ifft();
232+
std::vector<base> res(2 * n);
234233
for(size_t i = 0; i < n; i++) {
235-
base A0 = llround(B.get(i).real()) / 2;
236-
base A1 = llround(B.get(i).imag() + A.get(i).imag()) / 2;
237-
base A2 = llround(A.get(i).real()) / 2;
238-
res[i] = A0 + A1 * split - A2 * split * split;
234+
auto Ai = A.get(i) * conj(cvector::roots.get(2 * n + i));
235+
auto Bi = B.get(i) * conj(cvector::roots.get(2 * n + i));
236+
auto Ci = C.get(i) * conj(cvector::roots.get(2 * n + i));
237+
base A0 = llround(real(Ai));
238+
base A1 = llround(real(Ci));
239+
base A2 = llround(real(Bi));
240+
res[i] = A0 + A1 * split + A2 * split * split;
241+
base B0 = llround(imag(Ai));
242+
base B1 = llround(imag(Ci));
243+
base B2 = llround(imag(Bi));
244+
res[n + i] = B0 + B1 * split + B2 * split * split;
239245
}
240246
return res;
241247
}
242248
std::vector<base> operator *= (auto &&B) {
243-
return mul(B.A);
249+
return mul(B.A, B.B);
244250
}
245251

246252
auto operator * (dft const& B) const {
@@ -254,12 +260,12 @@ namespace cp_algo::math::fft {
254260
if(!as || !bs) {
255261
return 0;
256262
}
257-
return std::bit_ceil(as + bs - 1);
263+
return std::max(flen, std::bit_ceil(as + bs - 1) / 2);
258264
}
259265

260266
template<typename base>
261267
void mul(std::vector<base> &a, std::vector<base> const& b) {
262-
if(std::min(a.size(), b.size()) < magic) {
268+
if(std::min(a.size(), b.size()) < 1) {
263269
mul_slow(a, b);
264270
return;
265271
}
@@ -273,7 +279,7 @@ namespace cp_algo::math::fft {
273279
}
274280
template<typename base>
275281
void circular_mul(std::vector<base> &a, std::vector<base> const& b) {
276-
auto n = std::bit_ceil(a.size());
282+
auto n = std::max(flen, std::bit_ceil(max(a.size(), b.size())) / 2);
277283
auto A = dft<base>(a, n);
278284
if(a == b) {
279285
a = A *= dft<base>(A);

0 commit comments

Comments
 (0)