@@ -12,23 +12,27 @@ namespace cp_algo::math::fft {
// Per-instance working buffers; the constructor sizes both to n.
1212 cvector A, B;
// Shared scaling factor and its inverse; init() sets ifactor = base(1) / factor.
1313 static base factor, ifactor;
1414 using Int2 = base::Int2;
// Flag that init() sets to true once the static parameters have been chosen
// (this change renames it from init to _init, freeing the name for the method).
15- static bool init ;
15+ static bool _init ;
// Splitting radix; init() sets it to int(sqrt(base::mod())) + 1.
1616 static int split;
1717
18- dft ( auto const & a, size_t n): A(n), B(n ) {
19- if (!init ) {
18+ void init ( ) {
19+ if (!_init ) {
2020 factor = 1 + random::rng () % (base::mod () - 1 );
2121 split = int (std::sqrt (base::mod ())) + 1 ;
2222 ifactor = base (1 ) / factor;
23- init = true ;
23+ _init = true ;
2424 }
25+ }
26+
27+ dft (auto const & a, size_t n): A(n), B(n) {
28+ init ();
2529 base cur = factor;
2630 base step = bpow (factor, n);
2731 for (size_t i = 0 ; i < std::min (n, size (a)); i++) {
2832 auto splt = [&](size_t i, auto mul) {
29- auto ai = i < size (a) ? (a[i] * mul).rem () : 0 ;
30- auto rem = ai % split;
33+ auto ai = i < size (a) ? (a[i] * mul).getr () : 0 ;
3134 auto quo = ai / split;
35+ auto rem = ai % split;
3236 return std::pair{(ftype)rem, (ftype)quo};
3337 };
3438 auto [rai, qai] = splt (i, cur);
@@ -74,6 +78,33 @@ namespace cp_algo::math::fft {
7478 checkpoint (" dot" );
7579 }
7680
81+ void recover_mod (auto &&C, auto &res, size_t k) {
82+ size_t n = A.size ();
83+ auto splitsplit = base (split * split).getr ();
84+ base stepn = bpow (ifactor, n);
85+ base cur[] = {bpow (ifactor, 2 ), bpow (ifactor, 3 ), bpow (ifactor, 4 ), bpow (ifactor, 5 )};
86+ base step4 = cur[2 ];
87+ for (size_t i = 0 ; i < std::min (n, k); i += flen) {
88+ auto [Ax, Ay] = A.at (i);
89+ auto [Bx, By] = B.at (i);
90+ auto [Cx, Cy] = C.at (i);
91+ auto A0 = lround (Ax), A1 = lround (Cx), A2 = lround (Bx);
92+ auto B0 = lround (Ay), B1 = lround (Cy), B2 = lround (By);
93+ for (size_t j = 0 ; j < flen; j++) {
94+ if (i + j < k) {
95+ res[i + j] = A0[j] + A1[j] * split + A2[j] * splitsplit;
96+ res[i + j] *= cur[j];
97+ }
98+ if (i + j + n < k) {
99+ res[i + j + n] = B0[j] + B1[j] * split + B2[j] * splitsplit;
100+ res[i + j + n] *= cur[j] * stepn;
101+ }
102+ cur[j] *= step4;
103+ }
104+ }
105+ checkpoint (" recover mod" );
106+ }
107+
77108 void mul (auto &&C, auto const & D, auto &res, size_t k) {
78109 assert (A.size () == C.size ());
79110 size_t n = A.size ();
@@ -85,28 +116,7 @@ namespace cp_algo::math::fft {
85116 A.ifft ();
86117 B.ifft ();
87118 C.ifft ();
88- auto splitsplit = (base (split) * split).rem ();
89- base cur = ifactor * ifactor;
90- base step = bpow (ifactor, n);
91- for (size_t i = 0 ; i < std::min (n, k); i++) {
92- auto [Ax, Ay] = A.get (i);
93- auto [Bx, By] = B.get (i);
94- auto [Cx, Cy] = C.get (i);
95- Int2 A0 = llround (Ax);
96- Int2 A1 = llround (Cx);
97- Int2 A2 = llround (Bx);
98- res[i] = A0 + A1 * split + A2 * splitsplit;
99- res[i] *= cur;
100- if (n + i < k) {
101- Int2 B0 = llround (Ay);
102- Int2 B1 = llround (Cy);
103- Int2 B2 = llround (By);
104- res[n + i] = B0 + B1 * split + B2 * splitsplit;
105- res[n + i] *= cur * step;
106- }
107- cur *= ifactor;
108- }
109- checkpoint (" recover mod" );
119+ recover_mod (C, res, k);
110120 }
111121 void mul_inplace (auto &&B, auto & res, size_t k) {
112122 mul (B.A , B.B , res, k);
@@ -132,7 +142,7 @@ namespace cp_algo::math::fft {
132142 };
// Out-of-class definitions of the static data members of dft<base>, with the
// placeholder values they hold before init() runs.
133143 template <modint_type base> base dft<base>::factor = 1 ;
134144 template <modint_type base> base dft<base>::ifactor = 1 ;
// The "already initialised" flag starts out false (renamed from init to _init).
135- template <modint_type base> bool dft<base>::init = false ;
145+ template <modint_type base> bool dft<base>::_init = false ;
136146 template <modint_type base> int dft<base>::split = 1 ;
137147
138148 void mul_slow (auto &a, auto const & b, size_t k) {
0 commit comments