Skip to content

Commit 3f55a98

Browse files
committed
support multivariate convolution (truncated)
1 parent 5136241 commit 3f55a98

File tree

3 files changed

+134
-5
lines changed

3 files changed

+134
-5
lines changed

cp-algo/math/fft.hpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ namespace cp_algo::math::fft {
2929
}
3030
}
3131

32+
// Size-only constructor: builds both A and B from n, then runs init().
// No input sequence is consumed here.
dft(size_t n): A(n), B(n) {
    init();
}
3233
dft(auto const& a, size_t n): A(n), B(n) {
3334
init();
3435
base b2x32 = bpow(base(2), 32);
@@ -69,8 +70,8 @@ namespace cp_algo::math::fft {
6970
B.fft();
7071
}
7172
}
72-
73-
void dot(auto &&C, auto const& D) {
73+
template<bool overwrite = true>
74+
void dot(auto const& C, auto const& D, auto &Aout, auto &Bout, auto &Cout) const {
7475
cvector::exec_on_evals<1>(A.size() / flen, [&](size_t k, point rt) {
7576
k *= flen;
7677
auto [Ax, Ay] = A.at(k);
@@ -93,13 +94,23 @@ namespace cp_algo::math::fft {
9394
real(Dv)[0] = dx * real(rt) - dy * imag(rt);
9495
imag(Dv)[0] = dx * imag(rt) + dy * real(rt);
9596
}
96-
A.at(k) = AC;
97-
C.at(k) = AD + BC;
98-
B.at(k) = BD;
97+
if(overwrite) {
98+
Aout.at(k) = AC;
99+
Cout.at(k) = AD + BC;
100+
Bout.at(k) = BD;
101+
} else {
102+
Aout.at(k) += AC;
103+
Cout.at(k) += AD + BC;
104+
Bout.at(k) += BD;
105+
}
99106
});
100107
checkpoint("dot");
101108
}
102109

110+
// In-place form of dot: forwards to the five-argument overload in its
// overwriting mode, using this object's A and B as the first two outputs
// and C itself as the third.
void dot(auto &&C, auto const& D) {
    dot<true>(C, D, A, B, C);
}
113+
103114
void recover_mod(auto &&C, auto &res, size_t k) {
104115
size_t check = (k + flen - 1) / flen * flen;
105116
assert(res.size() >= check);

cp-algo/math/multivar.hpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#ifndef CP_ALGO_MATH_MULTIVAR_HPP
2+
#define CP_ALGO_MATH_MULTIVAR_HPP
3+
#include "../util/big_alloc.hpp"
4+
#include "../number_theory/modint.hpp"
5+
#include "../math/fft.hpp"
6+
namespace cp_algo::math::fft {
7+
template<modint_type base>
8+
struct multivar {
9+
std::vector<base, cp_algo::big_alloc<base>> data;
10+
std::vector<size_t, cp_algo::big_alloc<size_t>> ranks;
11+
std::vector<size_t> dim;
12+
size_t N;
13+
size_t rank(size_t i) {
14+
size_t cur = 1, res = 0, K = size(dim);
15+
for(auto ni: dim) {
16+
cur *= ni;
17+
res += i / cur;
18+
}
19+
return res % K;
20+
}
21+
multivar(std::vector<size_t> const& dim): dim(dim), N(
22+
std::ranges::fold_left(dim, 1, std::multiplies{})
23+
) {
24+
data.resize(N);
25+
ranks.resize(N);
26+
for(auto [i, x]: ranks | std::views::enumerate) {
27+
x = rank(i);
28+
}
29+
}
30+
void read() {
31+
for(auto &it: data) {
32+
std::cin >> it;
33+
}
34+
}
35+
void print() {
36+
for(auto &it: data) {
37+
std::cout << it << " ";
38+
}
39+
std::cout << "\n";
40+
}
41+
void mul(multivar<base> const& b) {
42+
assert(dim == b.dim);
43+
size_t K = size(dim);
44+
if(K == 0) {
45+
data[0] *= b.data[0];
46+
return;
47+
}
48+
std::vector<dft<base>> A, B;
49+
size_t M = std::max(flen, std::bit_ceil(2 * N - 1) / 2);
50+
for(size_t i = 0; i < K; i++) {
51+
A.emplace_back(data | std::views::enumerate | std::views::transform(
52+
[&](auto jx) {
53+
auto [j, x] = jx;
54+
return ranks[j] == i ? x : base(0);
55+
}
56+
), M);
57+
B.emplace_back(b.data | std::views::enumerate | std::views::transform(
58+
[&](auto jx) {
59+
auto [j, x] = jx;
60+
return ranks[j] == i ? x : base(0);
61+
}
62+
), M);
63+
}
64+
std::vector<dft<base>> C;
65+
for(size_t i = 0; i < K; i++) {
66+
dft<base> C(M);
67+
cvector X = C.A;
68+
for(size_t j = 0; j < K; j++) {
69+
size_t tj = (i - j + K) % K;
70+
A[j].template dot<false>(B[tj].A, B[tj].B, C.A, C.B, X);
71+
}
72+
std::vector<base, cp_algo::big_alloc<base>> res((N + flen - 1) / flen * flen);
73+
C.A.ifft();
74+
C.B.ifft();
75+
X.ifft();
76+
C.recover_mod(X, res, N);
77+
for(size_t j = 0; j < N; j++) {
78+
if(i == ranks[j]) {
79+
data[j] = res[j];
80+
}
81+
}
82+
}
83+
}
84+
};
85+
}
86+
#endif // CP_ALGO_MATH_MULTIVAR_HPP

verify/poly/multivar.test.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// @brief Multidimensional Convolution (Truncated)
2+
#define PROBLEM "https://judge.yosupo.jp/problem/multivariate_convolution"
3+
#pragma GCC optimize("Ofast,unroll-loops")
4+
#define CP_ALGO_CHECKPOINT
5+
#include <bits/stdc++.h>
6+
#include "blazingio/blazingio.min.hpp"
7+
#include "cp-algo/math/multivar.hpp"
8+
9+
using namespace std;
10+
using namespace cp_algo::math::fft;
11+
12+
const int mod = 998244353;
13+
using base = cp_algo::math::modint<mod>;
14+
15+
void solve() {
16+
int k;
17+
cin >> k;
18+
vector<size_t> N(k);
19+
for(auto &n: N) cin >> n;
20+
multivar<base> a(N), b(N);
21+
a.read();
22+
b.read();
23+
a.mul(b);
24+
a.print();
25+
}
26+
27+
// Entry point: switches off stdio synchronisation, unties cin, then runs
// the single test case.
signed main() {
    // Local debugging aid, intentionally left disabled:
    //freopen("input.txt", "r", stdin);
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    solve();
}

0 commit comments

Comments
 (0)