Skip to content

Commit 60b3d5c

Browse files
committed
Vectorize (kinda) fft, improve wildcard matching
1 parent 98f4c48 commit 60b3d5c

File tree

2 files changed

+180
-106
lines changed

2 files changed

+180
-106
lines changed

cp-algo/math/fft.hpp

Lines changed: 142 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
#include <vector>
99
#include <bit>
1010
namespace cp_algo::math::fft {
11-
using ftype = double;
12-
using point = std::complex<ftype>;
13-
14-
std::vector<point> w; // w[2^n + k] = exp(pi * k / (2^n))
15-
std::vector<int> bitr;// b[2^n + k] = bitreverse(k)
16-
const ftype pi = acos(-1);
17-
bool initiated = 0;
11+
const auto bitr = [](){
12+
std::vector<size_t> bitr(maxn);
13+
for(size_t n = 2; n < maxn; n *= 2) {
14+
for(size_t k = 0; k < n; k++) {
15+
bitr[n + k] = bitr[n + k / 2] / 2 + (k & 1) * (n / 2);
16+
}
17+
}
18+
return bitr;
19+
}();
1820
size_t bitreverse(size_t n, size_t k) {
1921
size_t hn = n / 2;
2022
if(k >= hn && n > 1) {
@@ -23,62 +25,123 @@ namespace cp_algo::math::fft {
2325
return 2 * bitr[hn + k];
2426
}
2527
}
26-
void init() {
27-
if(!initiated) {
28-
w.resize(maxn);
29-
bitr.resize(maxn);
30-
for(int i = 1; i < maxn; i *= 2) {
31-
int ti = i / 2;
32-
ftype arg = pi / i;
33-
point base = std::polar(1., arg);
34-
point cur = 1.;
35-
for(int j = 0; j < i; j++) {
36-
if((j & 15) == 0) {
37-
cur = std::polar(1., j * arg);
38-
}
39-
w[i + j] = cur;
40-
cur *= base;
41-
if(ti) {
42-
bitr[i + j] = 2 * bitr[ti + j % ti] + (j >= ti);
43-
}
44-
}
28+
29+
using ftype = double;
30+
static constexpr size_t bytes = 32;
31+
static constexpr size_t flen = bytes / sizeof(ftype);
32+
using point = std::complex<ftype>;
33+
using vftype [[gnu::vector_size(bytes)]] = ftype;
34+
using vpoint = std::complex<vftype>;
35+
36+
constexpr vftype to_vec(ftype x) {
37+
return vftype{} + x;
38+
}
39+
constexpr vpoint to_vec(point r) {
40+
return {to_vec(r.real()), to_vec(r.imag())};
41+
}
42+
struct cvector {
43+
std::vector<vftype> x, y;
44+
cvector() {}
45+
cvector(size_t n) {
46+
resize(n);
47+
}
48+
void resize(size_t n) {
49+
n = std::bit_ceil(std::max<size_t>(n, 4));
50+
if(size() != n) {
51+
x.resize(n / flen);
52+
y.resize(n / flen);
4553
}
46-
initiated = 1;
4754
}
48-
}
49-
50-
void ifft(auto &a, int n) {
51-
init();
52-
if(n == 1) {
53-
return;
55+
template<class pt = point>
56+
void set(size_t k, pt t) {
57+
if constexpr(std::is_same_v<pt, point>) {
58+
x[k / flen][k % flen] = real(t);
59+
y[k / flen][k % flen] = imag(t);
60+
} else {
61+
x[k / flen] = real(t);
62+
y[k / flen] = imag(t);
63+
}
64+
}
65+
template<class pt = point>
66+
pt get(size_t k) const {
67+
if constexpr(std::is_same_v<pt, point>) {
68+
return {x[k / flen][k % flen], y[k / flen][k % flen]};
69+
} else {
70+
return {x[k / flen], y[k / flen]};
71+
}
72+
}
73+
size_t size() const {
74+
return flen * std::size(x);
5475
}
55-
for(int i = 1; i < n; i *= 2) {
56-
for(int j = 0; j < n; j += 2 * i) {
57-
for(int k = j; k < j + i; k++) {
58-
std::tie(a[k], a[k + i]) = std::pair{
59-
a[k] + a[k + i] * conj(w[i + k - j]),
60-
a[k] - a[k + i] * conj(w[i + k - j])
76+
void dot(cvector const& t) {
77+
size_t n = size();
78+
for(size_t k = 0; k < n; k += flen) {
79+
set(k, get<vpoint>(k) * t.get<vpoint>(k));
80+
}
81+
}
82+
static const cvector roots;
83+
84+
void ifft() {
85+
size_t n = size();
86+
for(size_t i = 1; i < n; i *= 2) {
87+
for(size_t j = 0; j < n; j += 2 * i) {
88+
auto butterfly = [&]<class pt>(pt) {
89+
size_t step = sizeof(pt) / sizeof(point);
90+
for(size_t k = j; k < j + i; k += step) {
91+
auto T = get<pt>(k + i) * conj(roots.get<pt>(i + k - j));
92+
set(k + i, get<pt>(k) - T);
93+
set(k, get<pt>(k) + T);
94+
}
6195
};
96+
if(2 * i <= flen) {
97+
butterfly(point{});
98+
} else {
99+
butterfly(vpoint{});
100+
}
62101
}
63102
}
103+
for(size_t k = 0; k < n; k += flen) {
104+
set(k, get<vpoint>(k) /= to_vec(n));
105+
}
64106
}
65-
}
66-
void fft(auto &a, int n) {
67-
init();
68-
if(n == 1) {
69-
return;
70-
}
71-
for(int i = n / 2; i >= 1; i /= 2) {
72-
for(int j = 0; j < n; j += 2 * i) {
73-
for(int k = j; k < j + i; k++) {
74-
std::tie(a[k], a[k + i]) = std::pair{
75-
a[k] + a[k + i],
76-
(a[k] - a[k + i]) * w[i + k - j]
107+
void fft() {
108+
size_t n = size();
109+
for(size_t i = n / 2; i >= 1; i /= 2) {
110+
for(size_t j = 0; j < n; j += 2 * i) {
111+
auto butterfly = [&]<class pt>(pt) {
112+
size_t step = sizeof(pt) / sizeof(point);
113+
for(size_t k = j; k < j + i; k += step) {
114+
auto A = get<pt>(k) + get<pt>(k + i);
115+
auto B = get<pt>(k) - get<pt>(k + i);
116+
set(k, A);
117+
set(k + i, B * roots.get<pt>(i + k - j));
118+
}
77119
};
120+
if(2 * i <= flen) {
121+
butterfly(point{});
122+
} else {
123+
butterfly(vpoint{});
124+
}
78125
}
79126
}
80127
}
81-
}
128+
};
129+
const cvector cvector::roots = []() {
130+
cvector res(maxn);
131+
for(size_t n = 1; n < maxn; n *= 2) {
132+
auto base = std::polar(1., std::numbers::pi / n);
133+
point cur = 1;
134+
for(size_t k = 0; k < n; k++) {
135+
if((k & 15) == 0) {
136+
cur = std::polar(1., std::numbers::pi * k / n);
137+
}
138+
res.set(n + k, cur);
139+
cur *= base;
140+
}
141+
}
142+
return res;
143+
}();
144+
82145
template<typename base>
83146
void mul_slow(std::vector<base> &a, const std::vector<base> &b) {
84147
if(a.empty() || b.empty()) {
@@ -98,14 +161,14 @@ namespace cp_algo::math::fft {
98161

99162
template<typename base>
100163
struct dft {
101-
std::vector<point> A;
164+
cvector A;
102165

103166
dft(std::vector<base> const& a, size_t n): A(n) {
104167
for(size_t i = 0; i < std::min(n, a.size()); i++) {
105-
A[i] = a[i];
168+
A.set(i, a[i]);
106169
}
107170
if(n) {
108-
fft(A, n);
171+
A.fft();
109172
}
110173
}
111174

@@ -115,39 +178,33 @@ namespace cp_algo::math::fft {
115178
if(!n) {
116179
return std::vector<base>();
117180
}
118-
for(size_t i = 0; i < n; i++) {
119-
A[i] *= B[i];
120-
}
121-
ifft(A, n);
122-
for(size_t i = 0; i < n; i++) {
123-
A[i] /= n;
124-
}
125-
if constexpr (std::is_same_v<base, point>) {
126-
return A;
127-
} else {
128-
return {begin(A), end(A)};
181+
A.dot(B.A);
182+
A.ifft();
183+
std::vector<base> res(n);
184+
for(size_t k = 0; k < n; k++) {
185+
res[k] = A.get(k);
129186
}
187+
return res;
130188
}
131189

132190
auto operator * (dft const& B) const {
133191
return dft(*this) *= B;
134192
}
135193

136-
point& operator [](int i) {return A[i];}
137-
point operator [](int i) const {return A[i];}
194+
point operator [](int i) const {return A.get(i);}
138195
};
139196

140197
template<modint_type base>
141198
struct dft<base> {
142199
static constexpr int split = 1 << 15;
143-
std::vector<point> A;
200+
cvector A;
144201

145202
dft(std::vector<base> const& a, size_t n): A(n) {
146203
for(size_t i = 0; i < std::min(n, a.size()); i++) {
147-
A[i] = point(a[i].rem() % split, a[i].rem() / split);
204+
A.set(i, point{a[i].rem() % split, a[i].rem() / split});
148205
}
149206
if(n) {
150-
fft(A, n);
207+
A.fft();
151208
}
152209
}
153210

@@ -157,26 +214,25 @@ namespace cp_algo::math::fft {
157214
if(!n) {
158215
return std::vector<base>();
159216
}
160-
std::vector<point> C(n);
217+
cvector C(n);
161218
for(size_t i = 0; 2 * i <= n; i++) {
162219
size_t j = (n - i) % n;
163220
size_t x = bitreverse(n, i);
164221
size_t y = bitreverse(n, j);
165-
std::tie(C[x], A[x], C[y], A[y]) = std::make_tuple(
166-
A[x] * (B[x] + conj(B[y])),
167-
A[x] * (B[x] - conj(B[y])),
168-
A[y] * (B[y] + conj(B[x])),
169-
A[y] * (B[y] - conj(B[x]))
170-
);
171-
}
172-
ifft(C, n);
173-
ifft(A, n);
174-
int t = 2 * n;
222+
auto Ax = A.get(x), Bx = B[x];
223+
auto Ay = A.get(y), By = B[y];
224+
C.set(x, Ax * (Bx + conj(By)));
225+
A.set(x, Ax * (Bx - conj(By)));
226+
C.set(y, Ay * (By + conj(Bx)));
227+
A.set(y, Ay * (By - conj(Bx)));
228+
}
229+
A.ifft();
230+
C.ifft();
175231
std::vector<base> res(n);
176232
for(size_t i = 0; i < n; i++) {
177-
base A0 = llround(C[i].real() / t);
178-
base A1 = llround(C[i].imag() / t + A[i].imag() / t);
179-
base A2 = llround(A[i].real() / t);
233+
base A0 = llround(C.get(i).real()) / 2;
234+
base A1 = llround(C.get(i).imag() + A.get(i).imag()) / 2;
235+
base A2 = llround(A.get(i).real()) / 2;
180236
res[i] = A0 + A1 * split - A2 * split * split;
181237
}
182238
return res;
@@ -186,8 +242,7 @@ namespace cp_algo::math::fft {
186242
return dft(*this) *= B;
187243
}
188244

189-
point& operator [](int i) {return A[i];}
190-
point operator [](int i) const {return A[i];}
245+
point operator [](int i) const {return A.get(i);}
191246
};
192247

193248
size_t com_size(size_t as, size_t bs) {

verify/poly/wildcard.test.cpp

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,59 @@
33
#pragma GCC optimize("Ofast,unroll-loops")
44
#pragma GCC target("tune=native")
55
#define CP_ALGO_MAXN 1 << 19
6-
#include "cp-algo/math/poly.hpp"
6+
#include "cp-algo/math/fft.hpp"
7+
#include "cp-algo/random/rng.hpp"
78
#include <bits/stdc++.h>
89

910
using namespace std;
1011
using namespace cp_algo::math;
1112

12-
using base = complex<double>;
13-
using polyn = poly_t<base>;
13+
using fft::ftype;
14+
using fft::point;
15+
using fft::cvector;
16+
17+
auto semicorr(auto &a, auto &b) {
18+
b.resize(size(a));
19+
a.fft();
20+
b.fft();
21+
a.dot(b);
22+
a.ifft();
23+
return a;
24+
}
25+
26+
auto is_integer = [](point a) {
27+
static const double eps = 1e-8;
28+
return abs(imag(a)) < eps
29+
&& abs(real(a) - round(real(a))) < eps;
30+
};
1431

1532
string matches(string const& A, string const& B, char wild = '*') {
16-
static base c_to_int[2][26];
33+
static const int sigma = 26;
34+
static point project[2][sigma];
1735
static bool init = false;
1836
if(!init) {
1937
init = true;
20-
for(int i = 0; i < 26; i++) {
21-
c_to_int[0][i] = polar(1., (double)cp_algo::random::rng());
22-
c_to_int[1][i] = conj(c_to_int[0][i]);
38+
for(int i = 0; i < sigma; i++) {
39+
project[0][i] = polar(1., (ftype)cp_algo::random::rng());
40+
project[1][i] = conj(project[0][i]);
2341
}
2442
}
25-
string ST[2] = {A, B};
26-
polyn P[2];
43+
array ST = {&A, &B};
44+
cvector P[2];
2745
for(int i: {0, 1}) {
28-
vector<base> coeffs(size(ST[i]));
29-
for(size_t k = 0; k < size(ST[i]); k++) {
30-
coeffs[k] = base(ST[i][k] == wild ? 0 : c_to_int[i][ST[i][k] - 'a']);
46+
size_t N = size(*ST[i]);
47+
P[i].resize(N);
48+
for(size_t k = 0; k < N; k++) {
49+
char c = ST[i]->at(k);
50+
size_t idx = i ? N - k - 1 : k;
51+
point val = c == wild ? 0 : project[i][c - 'a'];
52+
P[i].set(idx, val);
3153
}
32-
P[i] = coeffs;
3354
}
34-
auto dist0 = polyn::inner_semicorr(P[0], P[1]);
35-
string ans(size(ST[0]) - size(ST[1]) + 1, '0');
36-
for(size_t j = 0; j <= size(ans); j++) {
37-
ans[j] = '0' + (
38-
abs(dist0[j].imag()) < 1e-8 && abs(dist0[j].real() - round(dist0[j].real())) < 1e-8
39-
);
55+
auto corr = semicorr(P[0], P[1]);
56+
string ans(size(A) - size(B) + 1, '0');
57+
for(size_t j = 0; j < size(ans); j++) {
58+
ans[j] = '0' + is_integer(corr.get(size(B) - 1 + j));
4059
}
4160
return ans;
4261
}

0 commit comments

Comments
 (0)