@@ -14,12 +14,15 @@ namespace cp_algo::math::fft {
// Alias for the Int2 type exposed by the modular `base` type.
1414 using Int2 = base::Int2;
// Guard flag: init() only runs its body while this is false and sets it to true afterwards.
1515 static bool _init;
// Splitting radix; init() sets it to int(std::sqrt(base::mod())) + 1.
1616 static int split;
// Vector constants filled in by init(): `mod` from base::mod(), `imod` from inv2(-base::mod()).
// Both are passed to montgomery_reduce / montgomery_mul in recover_mod.
17+ static u64x4 mod, imod;
1718
// One-time setup of the shared static state; every call after the first is a no-op
// because the body is guarded by `_init`.
1819 void init () {
1920 if (!_init) {
// factor = 1 + (rng() % (mod - 1)); the +1 keeps a zero remainder from yielding factor == 0.
2021 factor = 1 + random::rng () % (base::mod () - 1 );
// split = truncated square root of the modulus, plus one.
2122 split = int (std::sqrt (base::mod ())) + 1 ;
// ifactor is the modular inverse of factor (computed as base(1) / factor).
2223 ifactor = base (1 ) / factor;
// NOTE(review): `u64x4() + scalar` presumably broadcasts the scalar into every lane — confirm
// against the u64x4 definition.
24+ mod = u64x4 () + base::mod ();
// NOTE(review): inv2 is defined elsewhere; presumably the inverse modulo a power of two that
// Montgomery reduction needs — confirm.
25+ imod = u64x4 () + inv2 (-base::mod ());
2326 _init = true ;
2427 }
2528 }
@@ -79,28 +82,38 @@ namespace cp_algo::math::fft {
7982 }
8083
// Rebuilds modular results in `res` from the three transformed operands A, B (members) and C
// (argument), processing `flen` lanes per iteration.
//   C   - transform whose rounded values form the middle (times `split`) term
//   res - output range; indexed as res[i + j] and, when i + n < k, res[i + n + j]
//   k   - number of requested outputs; bounds the loop and gates the second half
// NOTE(review): this text looks like a rendered diff. Lines tagged `NN-` appear to be the removed
// per-element scalar version and lines tagged `NN+` the added vectorized version — confirm before
// treating it as compilable source.
8184 void recover_mod (auto &&C, auto &res, size_t k) {
// The vectorized path below stores all flen lanes of a block without a per-lane bound check,
// so the output length has to be a whole number of blocks.
85+ assert (size (res) % flen == 0 );
8286 size_t n = A.size ();
// Raw (getr) representation of split^2 reduced into `base`.
8387 auto splitsplit = base (split * split).getr ();
// Removed scalar version: per-lane multipliers ifactor^2..ifactor^5, advanced by ifactor^4.
84- base stepn = bpow (ifactor, n);
85- base cur[] = {bpow (ifactor, 2 ), bpow (ifactor, 3 ), bpow (ifactor, 4 ), bpow (ifactor, 5 )};
86- base step4 = cur[2 ];
// Added version: the same powers of ifactor, pre-multiplied by 2^64 (cur) or 2^32 (step4, stepn).
// NOTE(review): the extra powers of two presumably cancel the factors introduced by
// montgomery_reduce / montgomery_mul — confirm against those helpers.
88+ base b2x32 = bpow (base (2 ), 32 );
89+ base b2x64 = bpow (base (2 ), 64 );
90+ u64x4 cur = {
91+ (bpow (ifactor, 2 ) * b2x64).getr (),
92+ (bpow (ifactor, 3 ) * b2x64).getr (),
93+ (bpow (ifactor, 4 ) * b2x64).getr (),
94+ (bpow (ifactor, 5 ) * b2x64).getr ()
95+ };
96+ u64x4 step4 = u64x4{} + (bpow (ifactor, 4 ) * b2x32).getr ();
97+ u64x4 stepn = u64x4{} + (bpow (ifactor, n) * b2x32).getr ();
// One block of flen outputs per iteration, stopping at min(n, k).
8798 for (size_t i = 0 ; i < std::min (n, k); i += flen) {
// Each transform yields a pair at position i; the x parts feed res[i..], the y parts res[i + n..].
8899 auto [Ax, Ay] = A.at (i);
89100 auto [Bx, By] = B.at (i);
90101 auto [Cx, Cy] = C.at (i);
// Removed scalar version: per-lane combine guarded by `i + j < k` / `i + j + n < k`.
91- auto A0 = lround (Ax), A1 = lround (Cx), A2 = lround (Bx);
92- auto B0 = lround (Ay), B1 = lround (Cy), B2 = lround (By);
93- for (size_t j = 0 ; j < flen; j++) {
94- if (i + j < k) {
95- res[i + j] = A0[j] + A1[j] * split + A2[j] * splitsplit;
96- res[i + j] *= cur[j];
// Added helper: writes one block of flen results starting at index i.
// Note the argument order: the value passed as C becomes the `* split` term and the value
// passed as B becomes the `* splitsplit` term.
102+ auto set_i = [&](size_t i, auto A, auto B, auto C, auto mul) {
103+ auto A0 = lround (A), A1 = lround (C), A2 = lround (B);
// NOTE(review): base::modmod() is presumably an offset (a multiple of the modulus) that keeps
// the sum non-negative before the unsigned conversion — confirm.
104+ auto Ai = A0 + A1 * split + A2 * splitsplit + base::modmod ();
105+ auto Au = montgomery_reduce (u64x4 (Ai), mod, imod);
106+ Au = montgomery_mul (Au, mul, mod, imod);
// Single conditional subtraction of the modulus on lanes that are >= base::mod().
107+ Au = Au >= base::mod () ? Au - base::mod () : Au;
108+ for (size_t j = 0 ; j < flen; j++) {
// setr stores the value as the element's raw representation.
109+ res[i + j].setr (Au[j]);
97110 }
98- if (i + j + n < k) {
99- res[i + j + n] = B0[j] + B1[j] * split + B2[j] * splitsplit;
100- res[i + j + n] *= cur[j] * stepn;
101- }
102- cur[j] *= step4;
111+ };
// First half of the output uses the x components and the current multipliers.
112+ set_i (i, Ax, Bx, Cx, cur);
// Second half (offset by n) is written only when its block start lies below k; its multiplier
// is cur additionally multiplied by stepn.
113+ if (i + n < k) {
114+ set_i (i + n, Ay, By, Cy, montgomery_mul (cur, stepn, mod, imod));
103115 }
// Advance all four lane multipliers by step4 for the next block.
116+ cur = montgomery_mul (cur, step4, mod, imod);
104117 }
105118 checkpoint (" recover mod" );
106119 }
@@ -144,6 +157,8 @@ namespace cp_algo::math::fft {
// Out-of-class definitions of dft<base>'s static members with their initial values.
// init() overwrites all of these the first time it runs.
144157 template <modint_type base> base dft<base>::ifactor = 1 ;
145158 template <modint_type base> bool dft<base>::_init = false ;
146159 template <modint_type base> int dft<base>::split = 1 ;
// Value-initialized here; init() assigns the real values.
160+ template <modint_type base> u64x4 dft<base>::mod = {};
161+ template <modint_type base> u64x4 dft<base>::imod = {};
147162
148163 void mul_slow (auto &a, auto const & b, size_t k) {
149164 if (empty (a) || empty (b)) {
@@ -176,13 +191,13 @@ namespace cp_algo::math::fft {
176191 std::min (k, size (a)) + std::min (k, size (b)) - 1
177192 ) / 2 );
178193 auto A = dft<base>(a | std::views::take (k), n);
179- a.assign (k, 0 );
180- checkpoint (" reset a" );
194+ a.assign ((k / flen + 1 ) * flen, 0 );
181195 if (&a == &b) {
182196 A.mul (A, a, k);
183197 } else {
184198 A.mul_inplace (dft<base>(b | std::views::take (k), n), a, k);
185199 }
200+ a.resize (k);
186201 }
187202 void mul (auto &a, auto const & b) {
188203 size_t N = size (a) + size (b) - 1 ;
@@ -213,6 +228,7 @@ namespace cp_algo::math::fft {
213228 a[i + n] += ai;
214229 }
215230 a.resize (N);
231+ checkpoint (" karatsuba join" );
216232 } else if (size (a)) {
217233 mul_truncate (a, b, N);
218234 }
0 commit comments