@@ -84,7 +84,7 @@ namespace cp_algo::math::fft {
84
84
}
85
85
}
86
86
87
- auto operator *= (dft const & B) {
87
+ std::vector<base> operator *= (dft const & B) {
88
88
assert (A.size () == B.A .size ());
89
89
size_t n = A.size ();
90
90
if (!n) {
@@ -95,12 +95,14 @@ namespace cp_algo::math::fft {
95
95
}
96
96
fft (A, n);
97
97
reverse (begin (A) + 1 , end (A));
98
- std::vector<base> res (n);
99
98
for (size_t i = 0 ; i < n; i++) {
100
- res[i] = A[i];
101
- res[i] /= n;
99
+ A[i] /= n;
100
+ }
101
+ if constexpr (std::is_same_v<base, point>) {
102
+ return A;
103
+ } else {
104
+ return {begin (A), end (A)};
102
105
}
103
- return res;
104
106
}
105
107
106
108
auto operator * (dft const & B) const {
@@ -125,16 +127,22 @@ namespace cp_algo::math::fft {
125
127
}
126
128
}
127
129
128
- auto operator *= (dft const & B) {
130
+ std::vector<base> operator *= (dft const & B) {
129
131
assert (A.size () == B.A .size ());
130
132
size_t n = A.size ();
131
133
if (!n) {
132
134
return std::vector<base>();
133
135
}
134
136
std::vector<point> C (n);
135
- for (size_t i = 0 ; i < n; i++) {
136
- C[i] = A[i] * (B[i] + conj (B[(n - i) % n]));
137
- A[i] = A[i] * (B[i] - conj (B[(n - i) % n]));
137
+ for (size_t i = 0 ; 2 * i <= n; i++) {
138
+ int x = i;
139
+ int y = (n - i) % n;
140
+ std::tie (C[x], A[x], C[y], A[y]) = std::make_tuple (
141
+ A[x] * (B[x] + conj (B[y])),
142
+ A[x] * (B[x] - conj (B[y])),
143
+ A[y] * (B[y] + conj (B[x])),
144
+ A[y] * (B[y] - conj (B[x]))
145
+ );
138
146
}
139
147
fft (C, n);
140
148
fft (A, n);
0 commit comments