Skip to content

Commit 3a3ed82

Browse files
committed
Reduce roots of unity precalc
1 parent 36f1160 commit 3a3ed82

File tree

2 files changed

+60
-67
lines changed

2 files changed

+60
-67
lines changed

cp-algo/math/fft.hpp

Lines changed: 60 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -7,39 +7,31 @@
77
#include <cassert>
88
#include <vector>
99
#include <bit>
10-
namespace cp_algo::math::fft {
11-
const auto bitr = [](){
12-
std::vector<size_t> bitr(maxn);
13-
for(size_t n = 2; n < maxn; n *= 2) {
14-
for(size_t k = 0; k < n; k++) {
15-
bitr[n + k] = bitr[n + k / 2] / 2 + (k & 1) * (n / 2);
16-
}
17-
}
18-
return bitr;
19-
}();
20-
size_t bitreverse(size_t n, size_t k) {
21-
size_t hn = n / 2;
22-
if(k >= hn && n > 1) {
23-
return 2 * bitr[k] + 1;
24-
} else {
25-
return 2 * bitr[hn + k];
26-
}
27-
}
2810

11+
namespace cp_algo::math::fft {
2912
using ftype = double;
3013
static constexpr size_t bytes = 32;
3114
static constexpr size_t flen = bytes / sizeof(ftype);
3215
using point = std::complex<ftype>;
3316
using vftype [[gnu::vector_size(bytes)]] = ftype;
3417
using vpoint = std::complex<vftype>;
3518

36-
constexpr vftype to_vec(ftype x) {
37-
return vftype{} + x;
19+
#define WITH_IV(...) \
20+
[&]<size_t ... i>(std::index_sequence<i...>) { \
21+
return __VA_ARGS__; \
22+
}(std::make_index_sequence<flen>());
23+
24+
template<typename ft>
25+
constexpr ft to_ft(auto x) {
26+
return ft{} + x;
3827
}
39-
constexpr vpoint to_vec(point r) {
40-
return {to_vec(r.real()), to_vec(r.imag())};
28+
template<typename pt>
29+
constexpr pt to_pt(point r) {
30+
using ft = std::conditional_t<std::is_same_v<point, pt>, ftype, vftype>;
31+
return {to_ft<ft>(r.real()), to_ft<ft>(r.imag())};
4132
}
4233
struct cvector {
34+
static constexpr size_t pre_roots = 1 << 17;
4335
std::vector<vftype> x, y;
4436
cvector() {}
4537
cvector(size_t n) {
@@ -84,54 +76,74 @@ namespace cp_algo::math::fft {
8476
}
8577
}
8678
static const cvector roots;
79+
template<class pt = point>
80+
static pt root(size_t n, size_t k) {
81+
if(n < pre_roots) {
82+
return cvector::roots.get<pt>(n + k);
83+
} else {
84+
auto arg = std::numbers::pi / n;
85+
if constexpr(std::is_same_v<pt, point>) {
86+
return {cos(k * arg), sin(k * arg)};
87+
} else {
88+
return WITH_IV(pt{vftype{cos((k + i) * arg)...},
89+
vftype{sin((k + i) * arg)...}});
90+
}
91+
}
92+
}
93+
template<class pt = point>
94+
static void exec_on_roots(size_t n, size_t m, auto &&callback) {
95+
size_t step = sizeof(pt) / sizeof(point);
96+
pt cur;
97+
pt arg = to_pt<pt>(root<point>(n, step));
98+
for(size_t i = 0; i < m; i += step) {
99+
cur = (i & 63) == 0 || n < pre_roots ? root<pt>(n, i) : cur * arg;
100+
callback(i, cur);
101+
}
102+
}
87103

88104
void ifft() {
89105
size_t n = size();
90106
for(size_t i = 1; i < n; i *= 2) {
91107
for(size_t j = 0; j < n; j += 2 * i) {
92-
auto butterfly = [&]<class pt>(pt) {
93-
size_t step = sizeof(pt) / sizeof(point);
94-
for(size_t k = j; k < j + i; k += step) {
95-
auto T = get<pt>(k + i) * conj(roots.get<pt>(i + k - j));
96-
set(k + i, get<pt>(k) - T);
97-
set(k, get<pt>(k) + T);
98-
}
108+
auto butterfly = [&]<class pt>(size_t k, pt rt) {
109+
k += j;
110+
auto t = get<pt>(k + i) * conj(rt);
111+
set(k + i, get<pt>(k) - t);
112+
set(k, get<pt>(k) + t);
99113
};
100114
if(2 * i <= flen) {
101-
butterfly(point{});
115+
exec_on_roots(i, i, butterfly);
102116
} else {
103-
butterfly(vpoint{});
117+
exec_on_roots<vpoint>(i, i, butterfly);
104118
}
105119
}
106120
}
107121
for(size_t k = 0; k < n; k += flen) {
108-
set(k, get<vpoint>(k) /= to_vec(n));
122+
set(k, get<vpoint>(k) /= to_pt<vpoint>(n));
109123
}
110124
}
111125
void fft() {
112126
size_t n = size();
113127
for(size_t i = n / 2; i >= 1; i /= 2) {
114128
for(size_t j = 0; j < n; j += 2 * i) {
115-
auto butterfly = [&]<class pt>(pt) {
116-
size_t step = sizeof(pt) / sizeof(point);
117-
for(size_t k = j; k < j + i; k += step) {
118-
auto A = get<pt>(k) + get<pt>(k + i);
119-
auto B = get<pt>(k) - get<pt>(k + i);
120-
set(k, A);
121-
set(k + i, B * roots.get<pt>(i + k - j));
122-
}
129+
auto butterfly = [&]<class pt>(size_t k, pt rt) {
130+
k += j;
131+
auto A = get<pt>(k) + get<pt>(k + i);
132+
auto B = get<pt>(k) - get<pt>(k + i);
133+
set(k, A);
134+
set(k + i, B * rt);
123135
};
124136
if(2 * i <= flen) {
125-
butterfly(point{});
137+
exec_on_roots(i, i, butterfly);
126138
} else {
127-
butterfly(vpoint{});
139+
exec_on_roots<vpoint>(i, i, butterfly);
128140
}
129141
}
130142
}
131143
}
132144
};
133145
const cvector cvector::roots = []() {
134-
cvector res(2 * maxn);
146+
cvector res(pre_roots);
135147
for(size_t n = 1; n < res.size(); n *= 2) {
136148
auto base = std::polar(1., std::numbers::pi / n);
137149
point cur = 1;
@@ -145,15 +157,6 @@ namespace cp_algo::math::fft {
145157
}
146158
return res;
147159
}();
148-
point root(size_t n, size_t k) {
149-
if(n < maxn) {
150-
return cvector::roots.get(n + k);
151-
} else if(k % 2 == 0) {
152-
return root(n / 2, k / 2);
153-
} else {
154-
return std::polar(1., std::numbers::pi * k / n);
155-
}
156-
}
157160

158161
template<typename base>
159162
void mul_slow(std::vector<base> &a, const std::vector<base> &b) {
@@ -171,7 +174,7 @@ namespace cp_algo::math::fft {
171174
}
172175
}
173176
}
174-
177+
175178
template<typename base>
176179
struct dft {
177180
cvector A;
@@ -184,7 +187,7 @@ namespace cp_algo::math::fft {
184187
A.fft();
185188
}
186189
}
187-
190+
188191
std::vector<base> operator *= (dft const& B) {
189192
assert(A.size() == B.A.size());
190193
size_t n = A.size();
@@ -203,26 +206,17 @@ namespace cp_algo::math::fft {
203206
auto operator * (dft const& B) const {
204207
return dft(*this) *= B;
205208
}
206-
209+
207210
point operator [](int i) const {return A.get(i);}
208211
};
209212

210213
template<modint_type base>
211214
struct dft<base> {
212215
static constexpr int split = 1 << 15;
213216
cvector A, B;
214-
215-
void exec_on_roots(size_t n, size_t m, auto &&callback) {
216-
point cur = 1;
217-
point step = root(n, 1);
218-
for(size_t i = 0; i < m; i++) {
219-
callback(i, cur);
220-
cur = (i & 15) == 0 || 2 * n < maxn ? root(n, i + 1) : cur * step;
221-
}
222-
}
223217

224218
dft(std::vector<base> const& a, size_t n): A(n), B(n) {
225-
exec_on_roots(2 * n, size(a), [&](size_t i, point rt) {
219+
cvector::exec_on_roots(2 * n, size(a), [&](size_t i, point rt) {
226220
A.set(i % n, A.get(i % n) + ftype(a[i].rem() % split) * rt);
227221
B.set(i % n, B.get(i % n) + ftype(a[i].rem() / split) * rt);
228222

@@ -249,7 +243,7 @@ namespace cp_algo::math::fft {
249243
B.ifft();
250244
C.ifft();
251245
std::vector<base> res(2 * n);
252-
exec_on_roots(2 * n, n, [&](size_t i, point rt) {
246+
cvector::exec_on_roots(2 * n, n, [&](size_t i, point rt) {
253247
rt = conj(rt);
254248
auto Ai = A.get(i) * rt;
255249
auto Bi = B.get(i) * rt;

verify/poly/wildcard.test.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#define PROBLEM "https://judge.yosupo.jp/problem/wildcard_pattern_matching"
33
#pragma GCC optimize("Ofast,unroll-loops")
44
#pragma GCC target("tune=native")
5-
#define CP_ALGO_MAXN 1 << 19
65
#include "cp-algo/math/fft.hpp"
76
#include "cp-algo/random/rng.hpp"
87
#include <bits/stdc++.h>

0 commit comments

Comments
 (0)