6
6
#include < complex>
7
7
#include < cassert>
8
8
#include < vector>
9
+ #include < bit>
9
10
namespace cp_algo ::math::fft {
10
11
using ftype = double ;
11
12
using point = std::complex<ftype>;
@@ -70,48 +71,89 @@ namespace cp_algo::math::fft {
70
71
}
71
72
}
72
73
73
- template <modint_type base>
74
+ template <typename base>
74
75
struct dft {
76
+ std::vector<point> A;
77
+
78
+ dft (std::vector<base> const & a, size_t n): A(n) {
79
+ for (size_t i = 0 ; i < std::min (n, a.size ()); i++) {
80
+ A[i] = a[i];
81
+ }
82
+ if (n) {
83
+ fft (A, n);
84
+ }
85
+ }
86
+
87
+ auto operator *= (dft const & B) {
88
+ assert (A.size () == B.A .size ());
89
+ size_t n = A.size ();
90
+ if (!n) {
91
+ return std::vector<base>();
92
+ }
93
+ for (size_t i = 0 ; i < n; i++) {
94
+ A[i] *= B[i];
95
+ }
96
+ fft (A, n);
97
+ reverse (begin (A) + 1 , end (A));
98
+ std::vector<base> res (n);
99
+ for (size_t i = 0 ; i < n; i++) {
100
+ res[i] = A[i];
101
+ res[i] /= n;
102
+ }
103
+ return res;
104
+ }
105
+
106
+ auto operator * (dft const & B) const {
107
+ return dft (*this ) *= B;
108
+ }
109
+
110
+ point& operator [](int i) {return A[i];}
111
+ point operator [](int i) const {return A[i];}
112
+ };
113
+
114
+ template <modint_type base>
115
+ struct dft <base> {
75
116
static constexpr int split = 1 << 15 ;
76
117
std::vector<point> A;
77
118
78
119
dft (std::vector<base> const & a, size_t n): A(n) {
79
120
for (size_t i = 0 ; i < std::min (n, a.size ()); i++) {
80
- A[i] = point (
81
- a[i].rem () % split,
82
- a[i].rem () / split
83
- );
121
+ A[i] = point (a[i].rem () % split, a[i].rem () / split);
84
122
}
85
123
if (n) {
86
124
fft (A, n);
87
125
}
88
126
}
89
127
90
- auto operator * (dft const & B) {
128
+ auto operator *= (dft const & B) {
91
129
assert (A.size () == B.A .size ());
92
130
size_t n = A.size ();
93
131
if (!n) {
94
132
return std::vector<base>();
95
133
}
96
- std::vector<point> C (n), D (n) ;
134
+ std::vector<point> C (n);
97
135
for (size_t i = 0 ; i < n; i++) {
98
136
C[i] = A[i] * (B[i] + conj (B[(n - i) % n]));
99
- D [i] = A[i] * (B[i] - conj (B[(n - i) % n]));
137
+ A [i] = A[i] * (B[i] - conj (B[(n - i) % n]));
100
138
}
101
139
fft (C, n);
102
- fft (D , n);
140
+ fft (A , n);
103
141
reverse (begin (C) + 1 , end (C));
104
- reverse (begin (D ) + 1 , end (D ));
142
+ reverse (begin (A ) + 1 , end (A ));
105
143
int t = 2 * n;
106
144
std::vector<base> res (n);
107
145
for (size_t i = 0 ; i < n; i++) {
108
146
base A0 = llround (C[i].real () / t);
109
- base A1 = llround (C[i].imag () / t + D [i].imag () / t);
110
- base A2 = llround (D [i].real () / t);
147
+ base A1 = llround (C[i].imag () / t + A [i].imag () / t);
148
+ base A2 = llround (A [i].real () / t);
111
149
res[i] = A0 + A1 * split - A2 * split * split;
112
150
}
113
151
return res;
114
152
}
153
+
154
+ auto operator * (dft const & B) const {
155
+ return dft (*this ) *= B;
156
+ }
115
157
116
158
point& operator [](int i) {return A[i];}
117
159
point operator [](int i) const {return A[i];}
@@ -121,14 +163,10 @@ namespace cp_algo::math::fft {
121
163
if (!as || !bs) {
122
164
return 0 ;
123
165
}
124
- size_t n = as + bs - 1 ;
125
- while (__builtin_popcount (n) != 1 ) {
126
- n++;
127
- }
128
- return n;
166
+ return std::bit_ceil (as + bs - 1 );
129
167
}
130
168
131
- template <modint_type base>
169
+ template <typename base>
132
170
void mul (std::vector<base> &a, std::vector<base> const & b) {
133
171
if (std::min (a.size (), b.size ()) < magic) {
134
172
mul_slow (a, b);
@@ -137,30 +175,19 @@ namespace cp_algo::math::fft {
137
175
auto n = com_size (a.size (), b.size ());
138
176
auto A = dft<base>(a, n);
139
177
if (a == b) {
140
- a = A * A;
178
+ a = A *= A;
141
179
} else {
142
- a = A * dft<base>(b, n);
180
+ a = A *= dft<base>(b, n);
143
181
}
144
182
}
145
183
template <typename base>
146
- void mul (std::vector<base> &a, std::vector<base> const & b) {
147
- if (std::min (a.size (), b.size ()) < magic) {
148
- mul_slow (a, b);
149
- return ;
150
- }
151
- auto n = com_size (a.size (), b.size ());
152
- a.resize (n);
153
- auto B (b);
154
- B.resize (n);
155
- fft (a, n);
156
- fft (B, n);
157
- for (size_t i = 0 ; i < n; i++) {
158
- a[i] *= B[i];
159
- }
160
- fft (a, n);
161
- reverse (begin (a) + 1 , end (a));
162
- for (size_t i = 0 ; i < n; i++) {
163
- a[i] /= n;
184
+ void circular_mul (std::vector<base> &a, std::vector<base> const & b) {
185
+ auto n = std::bit_ceil (a.size ());
186
+ auto A = dft<base>(a, n);
187
+ if (a == b) {
188
+ a = A *= A;
189
+ } else {
190
+ a = A *= dft<base>(b, n);
164
191
}
165
192
}
166
193
}
0 commit comments