@@ -145,6 +145,15 @@ namespace cp_algo::math::fft {
145
145
}
146
146
return res;
147
147
}();
148
+ point root (size_t n, size_t k) {
149
+ if (n < maxn) {
150
+ return cvector::roots.get (n + k);
151
+ } else if (k % 2 == 0 ) {
152
+ return root (n / 2 , k / 2 );
153
+ } else {
154
+ return std::polar (1 ., std::numbers::pi * k / n);
155
+ }
156
+ }
148
157
149
158
template <typename base>
150
159
void mul_slow (std::vector<base> &a, const std::vector<base> &b) {
@@ -202,12 +211,22 @@ namespace cp_algo::math::fft {
202
211
struct dft <base> {
203
212
static constexpr int split = 1 << 15 ;
204
213
cvector A, B;
214
+
215
+ void exec_on_roots (size_t n, size_t m, auto &&callback) {
216
+ point cur = 1 ;
217
+ point step = root (n, 1 );
218
+ for (size_t i = 0 ; i < m; i++) {
219
+ callback (i, cur);
220
+ cur = (i & 15 ) == 0 || 2 * n < maxn ? root (n, i + 1 ) : cur * step;
221
+ }
222
+ }
205
223
206
224
dft (std::vector<base> const & a, size_t n): A(n), B(n) {
207
- for (size_t i = 0 ; i < size (a); i++) {
208
- A.set (i % n, A.get (i % n) + ftype (a[i].rem () % split) * cvector::roots.get (2 * n + i));
209
- B.set (i % n, B.get (i % n) + ftype (a[i].rem () / split) * cvector::roots.get (2 * n + i));
210
- }
225
+ exec_on_roots (2 * n, size (a), [&](size_t i, point rt) {
226
+ A.set (i % n, A.get (i % n) + ftype (a[i].rem () % split) * rt);
227
+ B.set (i % n, B.get (i % n) + ftype (a[i].rem () / split) * rt);
228
+
229
+ });
211
230
if (n) {
212
231
A.fft ();
213
232
B.fft ();
@@ -230,10 +249,11 @@ namespace cp_algo::math::fft {
230
249
B.ifft ();
231
250
C.ifft ();
232
251
std::vector<base> res (2 * n);
233
- for (size_t i = 0 ; i < n; i++) {
234
- auto Ai = A.get (i) * conj (cvector::roots.get (2 * n + i));
235
- auto Bi = B.get (i) * conj (cvector::roots.get (2 * n + i));
236
- auto Ci = C.get (i) * conj (cvector::roots.get (2 * n + i));
252
+ exec_on_roots (2 * n, n, [&](size_t i, point rt) {
253
+ rt = conj (rt);
254
+ auto Ai = A.get (i) * rt;
255
+ auto Bi = B.get (i) * rt;
256
+ auto Ci = C.get (i) * rt;
237
257
base A0 = llround (real (Ai));
238
258
base A1 = llround (real (Ci));
239
259
base A2 = llround (real (Bi));
@@ -242,7 +262,7 @@ namespace cp_algo::math::fft {
242
262
base B1 = llround (imag (Ci));
243
263
base B2 = llround (imag (Bi));
244
264
res[n + i] = B0 + B1 * split + B2 * split * split;
245
- }
265
+ });
246
266
return res;
247
267
}
248
268
std::vector<base> operator *= (auto &&B) {
0 commit comments