99namespace cp_algo ::math::fft {
1010 template <modint_type base>
1111 struct dft {
12- int split;
1312 cvector A, B;
1413 static base factor, ifactor;
1514 static bool init;
15+ static int split;
1616
1717 dft (auto const & a, size_t n): A(n), B(n) {
1818 if (!init) {
1919 factor = 1 + random::rng () % (base::mod () - 1 );
20+ split = int (std::sqrt (base::mod ())) + 1 ;
2021 ifactor = base (1 ) / factor;
2122 init = true ;
2223 }
23- split = int ( std::sqrt ( base::mod ())) + 1 ;
24- base cur = 1 ;
24+ base cur = factor ;
25+ base step = bpow (factor, n) ;
2526 cvector::exec_on_roots (2 * n, std::min (n, size (a)), [&](size_t i, auto rt) {
26- auto splt = [&](size_t i) {
27- auto ai = ftype (i < size (a) ? (a[i] * cur ).rem () : 0 );
27+ auto splt = [&](size_t i, auto mul ) {
28+ auto ai = ftype (i < size (a) ? (a[i] * mul ).rem () : 0 );
2829 auto rem = std::remainder (ai, split);
2930 auto quo = (ai - rem) / split;
3031 return std::pair{rem, quo};
3132 };
32- auto [rai, qai] = splt (i);
33- auto [rani, qani] = splt (n + i);
33+ auto [rai, qai] = splt (i, cur );
34+ auto [rani, qani] = splt (n + i, cur * step );
3435 A.set (i, point (rai, rani) * rt);
3536 B.set (i, point (qai, qani) * rt);
3637 cur *= factor;
@@ -42,7 +43,7 @@ namespace cp_algo::math::fft {
4243 }
4344 }
4445
45- void mul (auto &&C, auto const & D, auto &res, size_t k, [[maybe_unused]] base ifactor ) {
46+ void mul (auto &&C, auto const & D, auto &res, size_t k) {
4647 assert (A.size () == C.size ());
4748 size_t n = A.size ();
4849 if (!n) {
@@ -83,7 +84,7 @@ namespace cp_algo::math::fft {
8384 B.ifft ();
8485 C.ifft ();
8586 auto splitsplit = (base (split) * split).rem ();
86- base cur = 1 ;
87+ base cur = ifactor * ifactor ;
8788 base step = bpow (ifactor, n);
8889 cvector::exec_on_roots (2 * n, std::min (n, k), [&](size_t i, point rt) {
8990 rt = conj (rt);
@@ -95,23 +96,22 @@ namespace cp_algo::math::fft {
9596 int64_t A2 = llround (real (Bi));
9697 res[i] = A0 + A1 * split + A2 * splitsplit;
9798 res[i] *= cur;
98- if (n + i >= k) {
99- return ;
99+ if (n + i < k) {
100+ int64_t B0 = llround (imag (Ai));
101+ int64_t B1 = llround (imag (Ci));
102+ int64_t B2 = llround (imag (Bi));
103+ res[n + i] = B0 + B1 * split + B2 * splitsplit;
104+ res[n + i] *= cur * step;
100105 }
101- int64_t B0 = llround (imag (Ai));
102- int64_t B1 = llround (imag (Ci));
103- int64_t B2 = llround (imag (Bi));
104- res[n + i] = B0 + B1 * split + B2 * splitsplit;
105- res[n + i] *= cur * step;
106106 cur *= ifactor;
107107 });
108108 checkpoint (" recover mod" );
109109 }
110110 void mul_inplace (auto &&B, auto & res, size_t k) {
111- mul (B.A , B.B , res, k, ifactor * B. ifactor );
111+ mul (B.A , B.B , res, k);
112112 }
113113 void mul (auto const & B, auto & res, size_t k) {
114- mul (cvector (B.A ), B.B , res, k, ifactor * B. ifactor );
114+ mul (cvector (B.A ), B.B , res, k);
115115 }
116116 std::vector<base> operator *= (dft &B) {
117117 std::vector<base> res (2 * A.size ());
@@ -132,6 +132,7 @@ namespace cp_algo::math::fft {
132132 template <modint_type base> base dft<base>::factor = 1 ;
133133 template <modint_type base> base dft<base>::ifactor = 1 ;
134134 template <modint_type base> bool dft<base>::init = false ;
135+ template <modint_type base> int dft<base>::split = 1 ;
135136
136137 void mul_slow (auto &a, auto const & b, size_t k) {
137138 if (empty (a) || empty (b)) {
0 commit comments