Skip to content

Commit 603b474

Browse files
committed
Basic support for poly<complex<double>>, improved wildcard matching
1 parent ebe9752 commit 603b474

File tree

3 files changed

+51
-20
lines changed

3 files changed

+51
-20
lines changed

cp-algo/math/fft.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,5 +142,26 @@ namespace cp_algo::math::fft {
142142
a = A * dft<base>(b, n);
143143
}
144144
}
145+
template<typename base>
146+
void mul(std::vector<base> &a, std::vector<base> const& b) {
147+
if(std::min(a.size(), b.size()) < magic) {
148+
mul_slow(a, b);
149+
return;
150+
}
151+
auto n = com_size(a.size(), b.size());
152+
a.resize(n);
153+
auto B(b);
154+
B.resize(n);
155+
fft(a, n);
156+
fft(B, n);
157+
for(size_t i = 0; i < n; i++) {
158+
a[i] *= B[i];
159+
}
160+
fft(a, n);
161+
reverse(begin(a) + 1, end(a));
162+
for(size_t i = 0; i < n; i++) {
163+
a[i] /= n;
164+
}
165+
}
145166
}
146167
#endif // CP_ALGO_MATH_FFT_HPP

cp-algo/math/poly/impl/base.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
#include <iostream>
66
// really basic operations, typically taking O(n)
77
namespace cp_algo::math::poly::impl {
8-
void normalize(auto& p) {
9-
while(p.deg() >= 0 && p.lead() == 0) {
8+
template<typename polyn>
9+
void normalize(polyn& p) {
10+
while(p.deg() >= 0 && p.lead() == typename polyn::base(0)) {
1011
p.a.pop_back();
1112
}
1213
}

verify/poly/wildcard.test.cpp

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,41 @@
88
using namespace std;
99
using namespace cp_algo::math;
1010

11-
const int mod = 1e9 + 9;
12-
13-
using base = modint<mod>;
11+
using base = complex<double>;
1412
using polyn = poly_t<base>;
1513

16-
void solve() {
17-
string ST[2];
18-
cin >> ST[0] >> ST[1];
19-
polyn P[2][3];
14+
string matches(string const& A, string const& B, char wild = '*') {
15+
static base c_to_int[2][26];
16+
static bool init = false;
17+
if(!init) {
18+
for(int i = 0; i < 26; i++) {
19+
c_to_int[0][i] = polar(1., (double)cp_algo::random::rng());
20+
c_to_int[1][i] = conj(c_to_int[0][i]);
21+
}
22+
}
23+
string ST[2] = {A, B};
24+
polyn P[2];
2025
for(int i: {0, 1}) {
21-
for(int j: {1, 2, 3}) {
22-
vector<base> coeffs(size(ST[i]));
23-
for(size_t k = 0; k < size(ST[i]); k++) {
24-
coeffs[k] = bpow(base(ST[i][k] == '*' ? 0 : ST[i][k] - 'a' + 1), j);
25-
}
26-
P[i][j-1] = coeffs;
26+
vector<base> coeffs(size(ST[i]));
27+
for(size_t k = 0; k < size(ST[i]); k++) {
28+
coeffs[k] = base(ST[i][k] == wild ? 0 : c_to_int[i][ST[i][k] - 'a']);
2729
}
30+
P[i] = coeffs;
2831
}
29-
auto dist = polyn::semicorr(P[0][0], P[1][2])
30-
+ polyn::semicorr(P[0][2], P[1][0])
31-
- 2*polyn::semicorr(P[0][1], P[1][1]);
32+
auto dist0 = polyn::semicorr(P[0], P[1]);
3233
string ans(size(ST[0]) - size(ST[1]) + 1, '0');
3334
for(size_t j = 0; j <= size(ans); j++) {
34-
ans[j] = '0' + int(dist[j] == 0);
35+
ans[j] = '0' + (
36+
abs(dist0[j].imag()) < 1e-4 && abs(dist0[j].real() - round(dist0[j].real())) < 1e-4
37+
);
3538
}
36-
cout << ans << "\n";
39+
return ans;
40+
}
41+
42+
void solve() {
43+
string a, b;
44+
cin >> a >> b;
45+
cout << matches(a, b) << "\n";
3746
}
3847

3948
signed main() {

0 commit comments

Comments
 (0)