Skip to content

Commit 1ad82bb

Browse files
committed
(almost) bitreverseless fft
1 parent 6ae8a20 commit 1ad82bb

File tree

1 file changed

+37
-21
lines changed

1 file changed

+37
-21
lines changed

cp-algo/math/fft.hpp

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ namespace cp_algo::math::fft {
1515
std::vector<int> bitr;// b[2^n + k] = bitreverse(k)
1616
const ftype pi = acos(-1);
1717
bool initiated = 0;
18+
size_t bitreverse(size_t n, size_t k) {
19+
size_t hn = n / 2;
20+
if(k >= hn) {
21+
return 2 * bitr[k] + 1;
22+
} else {
23+
return 2 * bitr[hn + k];
24+
}
25+
}
1826
void init() {
1927
if(!initiated) {
2028
w.resize(maxn);
@@ -32,24 +40,34 @@ namespace cp_algo::math::fft {
3240
}
3341
}
3442

35-
void fft(auto &a, int n) {
43+
void ifft(auto &a, int n) {
3644
init();
3745
if(n == 1) {
3846
return;
3947
}
40-
int hn = n / 2;
41-
for(int i = 0; i < n; i++) {
42-
int ti = 2 * bitr[hn + i % hn] + (i > hn);
43-
if(i < ti) {
44-
std::swap(a[i], a[ti]);
48+
for(int i = 1; i < n; i *= 2) {
49+
for(int j = 0; j < n; j += 2 * i) {
50+
for(int k = j; k < j + i; k++) {
51+
std::tie(a[k], a[k + i]) = std::pair{
52+
a[k] + a[k + i] * conj(w[i + k - j]),
53+
a[k] - a[k + i] * conj(w[i + k - j])
54+
};
55+
}
4556
}
4657
}
47-
for(int i = 1; i < n; i *= 2) {
58+
}
59+
void fft(auto &a, int n) {
60+
init();
61+
if(n == 1) {
62+
return;
63+
}
64+
for(int i = n / 2; i >= 1; i /= 2) {
4865
for(int j = 0; j < n; j += 2 * i) {
4966
for(int k = j; k < j + i; k++) {
50-
point t = a[k + i] * w[i + k - j];
51-
a[k + i] = a[k] - t;
52-
a[k] += t;
67+
std::tie(a[k], a[k + i]) = std::pair{
68+
a[k] + a[k + i],
69+
(a[k] - a[k + i]) * w[i + k - j]
70+
};
5371
}
5472
}
5573
}
@@ -93,8 +111,7 @@ namespace cp_algo::math::fft {
93111
for(size_t i = 0; i < n; i++) {
94112
A[i] *= B[i];
95113
}
96-
fft(A, n);
97-
reverse(begin(A) + 1, end(A));
114+
ifft(A, n);
98115
for(size_t i = 0; i < n; i++) {
99116
A[i] /= n;
100117
}
@@ -126,28 +143,27 @@ namespace cp_algo::math::fft {
126143
fft(A, n);
127144
}
128145
}
129-
130-
std::vector<base> operator *= (dft const& B) {
146+
147+
std::vector<base> operator *= (dft B) {
131148
assert(A.size() == B.A.size());
132149
size_t n = A.size();
133150
if(!n) {
134151
return std::vector<base>();
135152
}
136153
std::vector<point> C(n);
137-
for(size_t i = 0; 2 * i <= n; i++) {
138-
int x = i;
139-
int y = (n - i) % n;
154+
for(size_t i = 0; 2 * i <= n; i++) {//
155+
size_t j = (n - i) % n;
156+
size_t x = bitreverse(n, i);
157+
size_t y = bitreverse(n, j);
140158
std::tie(C[x], A[x], C[y], A[y]) = std::make_tuple(
141159
A[x] * (B[x] + conj(B[y])),
142160
A[x] * (B[x] - conj(B[y])),
143161
A[y] * (B[y] + conj(B[x])),
144162
A[y] * (B[y] - conj(B[x]))
145163
);
146164
}
147-
fft(C, n);
148-
fft(A, n);
149-
reverse(begin(C) + 1, end(C));
150-
reverse(begin(A) + 1, end(A));
165+
ifft(C, n);
166+
ifft(A, n);
151167
int t = 2 * n;
152168
std::vector<base> res(n);
153169
for(size_t i = 0; i < n; i++) {

0 commit comments

Comments
 (0)