Skip to content

Commit ca0ad54

Browse files
committed
Taylor shift test, faster wildcard matching
1 parent fd7fde8 commit ca0ad54

File tree

4 files changed

+104
-39
lines changed

4 files changed

+104
-39
lines changed

cp-algo/math/fft.hpp

Lines changed: 65 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <complex>
77
#include <cassert>
88
#include <vector>
9+
#include <bit>
910
namespace cp_algo::math::fft {
1011
using ftype = double;
1112
using point = std::complex<ftype>;
@@ -70,48 +71,89 @@ namespace cp_algo::math::fft {
7071
}
7172
}
7273

73-
template<modint_type base>
74+
template<typename base>
7475
struct dft {
76+
std::vector<point> A;
77+
78+
dft(std::vector<base> const& a, size_t n): A(n) {
79+
for(size_t i = 0; i < std::min(n, a.size()); i++) {
80+
A[i] = a[i];
81+
}
82+
if(n) {
83+
fft(A, n);
84+
}
85+
}
86+
87+
auto operator *= (dft const& B) {
88+
assert(A.size() == B.A.size());
89+
size_t n = A.size();
90+
if(!n) {
91+
return std::vector<base>();
92+
}
93+
for(size_t i = 0; i < n; i++) {
94+
A[i] *= B[i];
95+
}
96+
fft(A, n);
97+
reverse(begin(A) + 1, end(A));
98+
std::vector<base> res(n);
99+
for(size_t i = 0; i < n; i++) {
100+
res[i] = A[i];
101+
res[i] /= n;
102+
}
103+
return res;
104+
}
105+
106+
auto operator * (dft const& B) const {
107+
return dft(*this) *= B;
108+
}
109+
110+
point& operator [](int i) {return A[i];}
111+
point operator [](int i) const {return A[i];}
112+
};
113+
114+
template<modint_type base>
115+
struct dft<base> {
75116
static constexpr int split = 1 << 15;
76117
std::vector<point> A;
77118

78119
dft(std::vector<base> const& a, size_t n): A(n) {
79120
for(size_t i = 0; i < std::min(n, a.size()); i++) {
80-
A[i] = point(
81-
a[i].rem() % split,
82-
a[i].rem() / split
83-
);
121+
A[i] = point(a[i].rem() % split, a[i].rem() / split);
84122
}
85123
if(n) {
86124
fft(A, n);
87125
}
88126
}
89127

90-
auto operator * (dft const& B) {
128+
auto operator *= (dft const& B) {
91129
assert(A.size() == B.A.size());
92130
size_t n = A.size();
93131
if(!n) {
94132
return std::vector<base>();
95133
}
96-
std::vector<point> C(n), D(n);
134+
std::vector<point> C(n);
97135
for(size_t i = 0; i < n; i++) {
98136
C[i] = A[i] * (B[i] + conj(B[(n - i) % n]));
99-
D[i] = A[i] * (B[i] - conj(B[(n - i) % n]));
137+
A[i] = A[i] * (B[i] - conj(B[(n - i) % n]));
100138
}
101139
fft(C, n);
102-
fft(D, n);
140+
fft(A, n);
103141
reverse(begin(C) + 1, end(C));
104-
reverse(begin(D) + 1, end(D));
142+
reverse(begin(A) + 1, end(A));
105143
int t = 2 * n;
106144
std::vector<base> res(n);
107145
for(size_t i = 0; i < n; i++) {
108146
base A0 = llround(C[i].real() / t);
109-
base A1 = llround(C[i].imag() / t + D[i].imag() / t);
110-
base A2 = llround(D[i].real() / t);
147+
base A1 = llround(C[i].imag() / t + A[i].imag() / t);
148+
base A2 = llround(A[i].real() / t);
111149
res[i] = A0 + A1 * split - A2 * split * split;
112150
}
113151
return res;
114152
}
153+
154+
auto operator * (dft const& B) const {
155+
return dft(*this) *= B;
156+
}
115157

116158
point& operator [](int i) {return A[i];}
117159
point operator [](int i) const {return A[i];}
@@ -121,14 +163,10 @@ namespace cp_algo::math::fft {
121163
if(!as || !bs) {
122164
return 0;
123165
}
124-
size_t n = as + bs - 1;
125-
while(__builtin_popcount(n) != 1) {
126-
n++;
127-
}
128-
return n;
166+
return std::bit_ceil(as + bs - 1);
129167
}
130168

131-
template<modint_type base>
169+
template<typename base>
132170
void mul(std::vector<base> &a, std::vector<base> const& b) {
133171
if(std::min(a.size(), b.size()) < magic) {
134172
mul_slow(a, b);
@@ -137,30 +175,19 @@ namespace cp_algo::math::fft {
137175
auto n = com_size(a.size(), b.size());
138176
auto A = dft<base>(a, n);
139177
if(a == b) {
140-
a = A * A;
178+
a = A *= A;
141179
} else {
142-
a = A * dft<base>(b, n);
180+
a = A *= dft<base>(b, n);
143181
}
144182
}
145183
template<typename base>
146-
void mul(std::vector<base> &a, std::vector<base> const& b) {
147-
if(std::min(a.size(), b.size()) < magic) {
148-
mul_slow(a, b);
149-
return;
150-
}
151-
auto n = com_size(a.size(), b.size());
152-
a.resize(n);
153-
auto B(b);
154-
B.resize(n);
155-
fft(a, n);
156-
fft(B, n);
157-
for(size_t i = 0; i < n; i++) {
158-
a[i] *= B[i];
159-
}
160-
fft(a, n);
161-
reverse(begin(a) + 1, end(a));
162-
for(size_t i = 0; i < n; i++) {
163-
a[i] /= n;
184+
void circular_mul(std::vector<base> &a, std::vector<base> const& b) {
185+
auto n = std::bit_ceil(a.size());
186+
auto A = dft<base>(a, n);
187+
if(a == b) {
188+
a = A *= A;
189+
} else {
190+
a = A *= dft<base>(b, n);
164191
}
165192
}
166193
}

cp-algo/math/poly.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,13 @@ namespace cp_algo::math {
503503
return a * b.reverse();
504504
}
505505

506+
// [x^k] (a semicorr b) = sum_i a{i+k} * b{i}
507+
static poly_t inner_semicorr(poly_t const& a, poly_t const& b) {
508+
auto ta = a.a;
509+
fft::circular_mul(ta, b.reverse().a);
510+
return poly_t(ta).div_xk(b.deg());
511+
}
512+
506513
// [x^k] (a semicorr b) = sum_i a{i+k} * b{i}
507514
static poly_t semicorr(poly_t const& a, poly_t const& b) {
508515
return corr(a, b).div_xk(b.deg());

verify/poly/taylor.test.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// @brief Polynomial Taylor Shift
2+
#define PROBLEM "https://judge.yosupo.jp/problem/polynomial_taylor_shift"
3+
#pragma GCC optimize("Ofast,unroll-loops")
4+
#pragma GCC target("avx2,tune=native")
5+
#include "cp-algo/math/poly.hpp"
6+
#include <bits/stdc++.h>
7+
8+
using namespace std;
9+
using namespace cp_algo::math;
10+
11+
const int mod = 998244353;
12+
using base = modint<mod>;
13+
using polyn = poly_t<base>;
14+
15+
void solve() {
16+
int n, c;
17+
cin >> n >> c;
18+
vector<base> a(n);
19+
copy_n(istream_iterator<base>(cin), n, begin(a));
20+
polyn(a).shift(c).print(n);
21+
}
22+
signed main() {
23+
//freopen("input.txt", "r", stdin);
24+
ios::sync_with_stdio(0);
25+
cin.tie(0);
26+
int t = 1;
27+
//cin >> t;
28+
while(t--) {
29+
solve();
30+
}
31+
}

verify/poly/wildcard.test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ string matches(string const& A, string const& B, char wild = '*') {
3030
}
3131
P[i] = coeffs;
3232
}
33-
auto dist0 = polyn::semicorr(P[0], P[1]);
33+
auto dist0 = polyn::inner_semicorr(P[0], P[1]);
3434
string ans(size(ST[0]) - size(ST[1]) + 1, '0');
3535
for(size_t j = 0; j <= size(ans); j++) {
3636
ans[j] = '0' + (

0 commit comments

Comments
 (0)