Skip to content

Commit 51732a6

Browse files
committed
support full dft
1 parent af553ce commit 51732a6

File tree

3 files changed

+80
-30
lines changed

3 files changed

+80
-30
lines changed

cp-algo/math/cvector.hpp

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,26 @@ namespace cp_algo::math::fft {
109109
});
110110
checkpoint("dot");
111111
}
112-
112+
template<bool partial = true>
113113
void ifft() {
114114
size_t n = size();
115+
if constexpr (!partial) {
116+
point pi(0, 1);
117+
exec_on_evals<4>(n / 4, [&](size_t k, point rt) {
118+
k *= 4;
119+
point v1 = conj(rt);
120+
point v2 = v1 * v1;
121+
point v3 = v1 * v2;
122+
auto A = get(k);
123+
auto B = get(k + 1);
124+
auto C = get(k + 2);
125+
auto D = get(k + 3);
126+
set(k, (A + B) + (C + D));
127+
set(k + 2, ((A + B) - (C + D)) * v2);
128+
set(k + 1, ((A - B) - pi * (C - D)) * v1);
129+
set(k + 3, ((A - B) + pi * (C - D)) * v3);
130+
});
131+
}
115132
bool parity = std::countr_zero(n) % 2;
116133
if(parity) {
117134
exec_on_evals<2>(n / (2 * flen), [&](size_t k, point rt) {
@@ -147,9 +164,14 @@ namespace cp_algo::math::fft {
147164
}
148165
checkpoint("ifft");
149166
for(size_t k = 0; k < n; k += flen) {
150-
set(k, get<vpoint>(k) /= vz + (ftype)(n / flen));
167+
if constexpr (partial) {
168+
set(k, get<vpoint>(k) /= vz + ftype(n / flen));
169+
} else {
170+
set(k, get<vpoint>(k) /= vz + ftype(n));
171+
}
151172
}
152173
}
174+
template<bool partial = true>
153175
void fft() {
154176
size_t n = size();
155177
bool parity = std::countr_zero(n) % 2;
@@ -185,6 +207,23 @@ namespace cp_algo::math::fft {
185207
at(k) += t;
186208
});
187209
}
210+
if constexpr (!partial) {
211+
point pi(0, 1);
212+
exec_on_evals<4>(n / 4, [&](size_t k, point rt) {
213+
k *= 4;
214+
point v1 = rt;
215+
point v2 = v1 * v1;
216+
point v3 = v1 * v2;
217+
auto A = get(k);
218+
auto B = get(k + 1) * v1;
219+
auto C = get(k + 2) * v2;
220+
auto D = get(k + 3) * v3;
221+
set(k, (A + C) + (B + D));
222+
set(k + 1, (A + C) - (B + D));
223+
set(k + 2, (A - C) + pi * (B - D));
224+
set(k + 3, (A - C) - pi * (B - D));
225+
});
226+
}
188227
checkpoint("fft");
189228
}
190229
static constexpr size_t pre_evals = 1 << 16;

cp-algo/math/fft.hpp

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace cp_algo::math::fft {
3030
}
3131

3232
dft(size_t n): A(n), B(n) {init();}
33-
dft(auto const& a, size_t n): A(n), B(n) {
33+
dft(auto const& a, size_t n, bool partial = true): A(n), B(n) {
3434
init();
3535
base b2x32 = bpow(base(2), 32);
3636
u64x4 cur = {
@@ -66,35 +66,47 @@ namespace cp_algo::math::fft {
6666
}
6767
checkpoint("dft init");
6868
if(n) {
69-
A.fft();
70-
B.fft();
69+
if(partial) {
70+
A.fft();
71+
B.fft();
72+
} else {
73+
A.template fft<false>();
74+
B.template fft<false>();
75+
}
7176
}
7277
}
73-
template<bool overwrite = true>
78+
template<bool overwrite = true, bool partial = true>
7479
void dot(auto const& C, auto const& D, auto &Aout, auto &Bout, auto &Cout) const {
7580
cvector::exec_on_evals<1>(A.size() / flen, [&](size_t k, point rt) {
7681
k *= flen;
77-
auto [Ax, Ay] = A.at(k);
78-
auto [Bx, By] = B.at(k);
7982
vpoint AC, AD, BC, BD;
8083
AC = AD = BC = BD = vz;
8184
auto Cv = C.at(k), Dv = D.at(k);
82-
for (size_t i = 0; i < flen; i++) {
83-
vpoint Av = {vz + Ax[i], vz + Ay[i]}, Bv = {vz + Bx[i], vz + By[i]};
84-
AC += Av * Cv; AD += Av * Dv;
85-
BC += Bv * Cv; BD += Bv * Dv;
86-
real(Cv) = rotate_right(real(Cv));
87-
imag(Cv) = rotate_right(imag(Cv));
88-
real(Dv) = rotate_right(real(Dv));
89-
imag(Dv) = rotate_right(imag(Dv));
90-
auto cx = real(Cv)[0], cy = imag(Cv)[0];
91-
auto dx = real(Dv)[0], dy = imag(Dv)[0];
92-
real(Cv)[0] = cx * real(rt) - cy * imag(rt);
93-
imag(Cv)[0] = cx * imag(rt) + cy * real(rt);
94-
real(Dv)[0] = dx * real(rt) - dy * imag(rt);
95-
imag(Dv)[0] = dx * imag(rt) + dy * real(rt);
85+
if constexpr(partial) {
86+
auto [Ax, Ay] = A.at(k);
87+
auto [Bx, By] = B.at(k);
88+
for (size_t i = 0; i < flen; i++) {
89+
vpoint Av = {vz + Ax[i], vz + Ay[i]}, Bv = {vz + Bx[i], vz + By[i]};
90+
AC += Av * Cv; AD += Av * Dv;
91+
BC += Bv * Cv; BD += Bv * Dv;
92+
real(Cv) = rotate_right(real(Cv));
93+
imag(Cv) = rotate_right(imag(Cv));
94+
real(Dv) = rotate_right(real(Dv));
95+
imag(Dv) = rotate_right(imag(Dv));
96+
auto cx = real(Cv)[0], cy = imag(Cv)[0];
97+
auto dx = real(Dv)[0], dy = imag(Dv)[0];
98+
real(Cv)[0] = cx * real(rt) - cy * imag(rt);
99+
imag(Cv)[0] = cx * imag(rt) + cy * real(rt);
100+
real(Dv)[0] = dx * real(rt) - dy * imag(rt);
101+
imag(Dv)[0] = dx * imag(rt) + dy * real(rt);
102+
}
103+
} else {
104+
AC = A.at(k) * Cv;
105+
AD = A.at(k) * Dv;
106+
BC = B.at(k) * Cv;
107+
BD = B.at(k) * Dv;
96108
}
97-
if(overwrite) {
109+
if constexpr (overwrite) {
98110
Aout.at(k) = AC;
99111
Cout.at(k) = AD + BC;
100112
Bout.at(k) = BD;

cp-algo/math/multivar.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,27 +56,26 @@ namespace cp_algo::math::fft {
5656
auto [j, x] = jx;
5757
return ranks[j] == i ? x : base(0);
5858
}
59-
), M);
59+
), M, false);
6060
B.emplace_back(b.data | std::views::enumerate | std::views::transform(
6161
[&](auto jx) {
6262
auto [j, x] = jx;
6363
return ranks[j] == i ? x : base(0);
6464
}
65-
), M);
65+
), M, false);
6666
}
67-
std::vector<dft<base>> C;
6867
for(size_t i = 0; i < K; i++) {
6968
dft<base> C(M);
7069
cvector X = C.A;
7170
for(size_t j = 0; j < K; j++) {
7271
size_t tj = (i - j + K) % K;
73-
A[j].template dot<false>(B[tj].A, B[tj].B, C.A, C.B, X);
72+
A[j].template dot<false, false>(B[tj].A, B[tj].B, C.A, C.B, X);
7473
}
7574
checkpoint("dot");
7675
std::vector<base, cp_algo::big_alloc<base>> res((N + flen - 1) / flen * flen);
77-
C.A.ifft();
78-
C.B.ifft();
79-
X.ifft();
76+
C.A.template ifft<false>();
77+
C.B.template ifft<false>();
78+
X.template ifft<false>();
8079
C.recover_mod(X, res, N);
8180
for(size_t j = 0; j < N; j++) {
8281
if(i == ranks[j]) {

0 commit comments

Comments
 (0)