Skip to content

Commit 08816d6

Browse files
committed
Improve convolution
1 parent 2348db3 commit 08816d6

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

cp-algo/math/fft.hpp

Lines changed: 20 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ namespace cp_algo::math::fft {
46 46
resize(n);
47 47
}
48 48
void resize(size_t n) {
49-
n = std::bit_ceil(std::max<size_t>(n, flen));
49+
n = std::bit_ceil(std::max<size_t>(n, 4));
50 50
if(size() != n) {
51 51
x.resize(n / flen);
52 52
y.resize(n / flen);
@@ -201,45 +201,50 @@ namespace cp_algo::math::fft {
201 201

202 202
dft(std::vector<base> const& a, size_t n): A(n) {
203 203
for(size_t i = 0; i < std::min(n, a.size()); i++) {
204-
A.set(i, point{a[i].rem() % split, a[i].rem() / split});
204+
A.set(i, point{
205+
ftype(a[i].rem() % split),
206+
ftype(a[i].rem() / split)
207+
});
205 208
}
206 209
if(n) {
207 210
A.fft();
208 211
}
209 212
}
210 213

211-
std::vector<base> operator *= (dft const& B) {
212-
assert(A.size() == B.A.size());
214+
std::vector<base> mul(auto &&B) {
215+
assert(A.size() == B.size());
213 216
size_t n = A.size();
214 217
if(!n) {
215 218
return std::vector<base>();
216 219
}
217-
cvector C(n);
218 220
for(size_t i = 0; 2 * i <= n; i++) {
219 221
size_t j = (n - i) % n;
220 222
size_t x = bitreverse(n, i);
221 223
size_t y = bitreverse(n, j);
222-
auto Ax = A.get(x), Bx = B[x];
223-
auto Ay = A.get(y), By = B[y];
224-
C.set(x, Ax * (Bx + conj(By)));
224+
auto Ax = A.get(x), Bx = B.get(x);
225+
auto Ay = A.get(y), By = B.get(y);
226+
B.set(x, Ax * (Bx + conj(By)));
225 227
A.set(x, Ax * (Bx - conj(By)));
226-
C.set(y, Ay * (By + conj(Bx)));
228+
B.set(y, Ay * (By + conj(Bx)));
227 229
A.set(y, Ay * (By - conj(Bx)));
228 230
}
229 231
A.ifft();
230-
C.ifft();
232+
B.ifft();
231 233
std::vector<base> res(n);
232 234
for(size_t i = 0; i < n; i++) {
233-
base A0 = llround(C.get(i).real()) / 2;
234-
base A1 = llround(C.get(i).imag() + A.get(i).imag()) / 2;
235+
base A0 = llround(B.get(i).real()) / 2;
236+
base A1 = llround(B.get(i).imag() + A.get(i).imag()) / 2;
235 237
base A2 = llround(A.get(i).real()) / 2;
236 238
res[i] = A0 + A1 * split - A2 * split * split;
237 239
}
238 240
return res;
239 241
}
242+
std::vector<base> operator *= (auto &&B) {
243+
return mul(B.A);
244+
}
240 245

241 246
auto operator * (dft const& B) const {
242-
return dft(*this) *= B;
247+
return dft(*this) *= dft(B);
243 248
}
244 249

245 250
point operator [](int i) const {return A.get(i);}
@@ -261,7 +266,7 @@ namespace cp_algo::math::fft {
261 266
auto n = com_size(a.size(), b.size());
262 267
auto A = dft<base>(a, n);
263 268
if(a == b) {
264-
a = A *= A;
269+
a = A *= dft<base>(A);
265 270
} else {
266 271
a = A *= dft<base>(b, n);
267 272
}
@@ -271,7 +276,7 @@ namespace cp_algo::math::fft {
271 276
auto n = std::bit_ceil(a.size());
272 277
auto A = dft<base>(a, n);
273 278
if(a == b) {
274-
a = A *= A;
279+
a = A *= dft<base>(A);
275 280
} else {
276 281
a = A *= dft<base>(b, n);
277 282
}

0 commit comments

Comments
 (0)