22#define CP_ALGO_MATH_FFT_HPP
33#include " ../number_theory/modint.hpp"
44#include " ../util/checkpoint.hpp"
5+ #include " ../random/rng.hpp"
56#include " cvector.hpp"
6- #include < ranges>
77#include < iostream>
8+ #include < ranges>
89namespace cp_algo ::math::fft {
910 template <modint_type base>
1011 struct dft {
1112 int split;
1213 cvector A, B;
13-
14+ static base factor, ifactor;
15+ static bool init;
16+
1417 dft (auto const & a, size_t n): A(n), B(n) {
18+ if (!init) {
19+ factor = 1 + random::rng () % (base::mod () - 1 );
20+ ifactor = base (1 ) / factor;
21+ init = true ;
22+ }
1523 split = int (std::sqrt (base::mod ())) + 1 ;
24+ base cur = 1 ;
1625 cvector::exec_on_roots (2 * n, std::min (n, size (a)), [&](size_t i, auto rt) {
1726 auto splt = [&](size_t i) {
27+ #ifdef CP_ALGO_FFT_RANDOMIZER
28+ auto ai = ftype (i < size (a) ? (a[i] * cur).rem () : 0 );
29+ cur *= factor;
30+ #else
1831 auto ai = ftype (i < size (a) ? a[i].rem () : 0 );
32+ #endif
1933 auto rem = std::remainder (ai, split);
2034 auto quo = (ai - rem) / split;
2135 return std::pair{rem, quo};
@@ -32,7 +46,7 @@ namespace cp_algo::math::fft {
3246 }
3347 }
3448
35- void mul (auto &&C, auto const & D, auto &res, size_t k) {
49+ void mul (auto &&C, auto const & D, auto &res, size_t k, [[maybe_unused]] base ifactor ) {
3650 assert (A.size () == C.size ());
3751 size_t n = A.size ();
3852 if (!n) {
@@ -73,6 +87,8 @@ namespace cp_algo::math::fft {
7387 B.ifft ();
7488 C.ifft ();
7589 auto splitsplit = (base (split) * split).rem ();
90+ base cur = 1 ;
91+ base step = bpow (ifactor, n);
7692 cvector::exec_on_roots (2 * n, std::min (n, k), [&](size_t i, point rt) {
7793 rt = conj (rt);
7894 auto Ai = A.get (i) * rt;
@@ -82,21 +98,28 @@ namespace cp_algo::math::fft {
8298 int64_t A1 = llround (real (Ci));
8399 int64_t A2 = llround (real (Bi));
84100 res[i] = A0 + A1 * split + A2 * splitsplit;
101+ #ifdef CP_ALGO_FFT_RANDOMIZER
102+ res[i] *= cur;
103+ #endif
85104 if (n + i >= k) {
86105 return ;
87106 }
88107 int64_t B0 = llround (imag (Ai));
89108 int64_t B1 = llround (imag (Ci));
90109 int64_t B2 = llround (imag (Bi));
91110 res[n + i] = B0 + B1 * split + B2 * splitsplit;
111+ #ifdef CP_ALGO_FFT_RANDOMIZER
112+ res[n + i] *= cur * step;
113+ cur *= ifactor;
114+ #endif
92115 });
93116 checkpoint (" recover mod" );
94117 }
95118 void mul_inplace (auto &&B, auto & res, size_t k) {
96- mul (B.A , B.B , res, k);
119+ mul (B.A , B.B , res, k, ifactor * B. ifactor );
97120 }
98121 void mul (auto const & B, auto & res, size_t k) {
99- mul (cvector (B.A ), B.B , res, k);
122+ mul (cvector (B.A ), B.B , res, k, ifactor * B. ifactor );
100123 }
101124 std::vector<base> operator *= (dft &B) {
102125 std::vector<base> res (2 * A.size ());
@@ -111,9 +134,12 @@ namespace cp_algo::math::fft {
111134 auto operator * (dft const & B) const {
112135 return dft (*this ) *= B;
113136 }
114-
137+
115138 point operator [](int i) const {return A.get (i);}
116139 };
140+ template <modint_type base> base dft<base>::factor = 1 ;
141+ template <modint_type base> base dft<base>::ifactor = 1 ;
142+ template <modint_type base> bool dft<base>::init = false ;
117143
118144 void mul_slow (auto &a, auto const & b, size_t k) {
119145 if (empty (a) || empty (b)) {
@@ -155,8 +181,36 @@ namespace cp_algo::math::fft {
155181 }
156182 }
157183 void mul (auto &a, auto const & b) {
158- if (size (a)) {
159- mul_truncate (a, b, size (a) + size (b) - 1 );
184+ size_t N = size (a) + size (b) - 1 ;
185+ if (std::max (size (a), size (b)) > (1 << 23 )) {
186+ // do karatsuba to save memory
187+ auto n = (std::max (size (a), size (b)) + 1 ) / 2 ;
188+ auto a0 = to<std::vector>(a | std::views::take (n));
189+ auto a1 = to<std::vector>(a | std::views::drop (n));
190+ auto b0 = to<std::vector>(b | std::views::take (n));
191+ auto b1 = to<std::vector>(b | std::views::drop (n));
192+ a0.resize (n); a1.resize (n);
193+ b0.resize (n); b1.resize (n);
194+ auto a01 = to<std::vector>(std::views::zip_transform (std::plus{}, a0, a1));
195+ auto b01 = to<std::vector>(std::views::zip_transform (std::plus{}, b0, b1));
196+ mul (a0, b0);
197+ mul (a1, b1);
198+ mul (a01, b01);
199+ a.assign (4 * n, 0 );
200+ for (auto [i, ai]: a0 | std::views::enumerate) {
201+ a[i] += ai;
202+ a[i + n] -= ai;
203+ }
204+ for (auto [i, ai]: a1 | std::views::enumerate) {
205+ a[i + n] -= ai;
206+ a[i + 2 * n] += ai;
207+ }
208+ for (auto [i, ai]: a01 | std::views::enumerate) {
209+ a[i + n] += ai;
210+ }
211+ a.resize (N);
212+ } else if (size (a)) {
213+ mul_truncate (a, b, N);
160214 }
161215 }
162216}
0 commit comments