@@ -70,25 +70,38 @@ namespace cp_algo::math::poly::impl {
70
70
}
71
71
return powmod_hint (p, k, md, md.reverse ().inv (md.deg () + 1 ));
72
72
}
73
-
74
- auto interleave (auto const & p) {
75
- auto [p0, p1] = p.bisect ();
76
- return p0 * p0 - (p1 * p1).mul_xk (1 );
77
- }
78
73
template <typename poly>
79
- poly inv (poly const & q, int64_t k, size_t n) {
74
+ poly& inv_inplace (poly& q, int64_t k, size_t n) {
75
+ using poly_t = std::decay_t <poly>;
76
+ using base = poly_t ::base;
80
77
if (k <= std::max<int64_t >(n, size (q.a ))) {
81
- return q.inv (k + n).div_xk (k);
78
+ return q.inv_inplace (k + n).div_xk_inplace (k);
82
79
}
83
80
if (k % 2 ) {
84
- return inv (q, k - 1 , n + 1 ).div_xk (1 );
81
+ return inv_inplace (q, k - 1 , n + 1 ).div_xk_inplace (1 );
85
82
}
86
-
87
- auto qq = inv (interleave (q), k / 2 - q.deg () / 2 , (n + 1 ) / 2 + q.deg () / 2 );
88
- auto [q0, q1] = q.negx ().bisect ();
89
- return (
90
- (q0 * qq).x2 () + (q1 * qq).x2 ().mul_xk (1 )
91
- ).div_xk (2 *q0.deg ()).mod_xk (n);
83
+ auto [q0, q1] = q.bisect ();
84
+ auto qq = q0 * q0 - (q1 * q1).mul_xk_inplace (1 );
85
+ inv_inplace (qq, k / 2 - q.deg () / 2 , (n + 1 ) / 2 + q.deg () / 2 );
86
+ int N = fft::com_size (size (q0.a ), size (qq.a ));
87
+ auto q0f = fft::dft<base>(q0.a , N);
88
+ auto q1f = fft::dft<base>(q1.a , N);
89
+ auto qqf = fft::dft<base>(qq.a , N);
90
+ int M = q0.deg () + (n + 1 ) / 2 ;
91
+ std::deque<base> A (M), B (M);
92
+ q0f.mul (fft::dft<base>(qqf), A, M);
93
+ q1f.mul (qqf, B, M);
94
+ q.a .resize (n + 1 );
95
+ for (size_t i = 0 ; i < n; i += 2 ) {
96
+ q.a [i] = A[q0.deg () + i / 2 ];
97
+ q.a [i + 1 ] = -B[q0.deg () + i / 2 ];
98
+ }
99
+ q.a .pop_back ();
100
+ q.normalize ();
101
+ return q;
102
+
103
+ q = (q0 * qq).x2 () - (q1 * qq).x2 ().mul_xk (1 );
104
+ return q.div_xk_inplace (2 * q0.deg ()).mod_xk_inplace (n);
92
105
}
93
106
template <typename poly>
94
107
poly& inv_inplace (poly& p, size_t n) {
@@ -100,7 +113,7 @@ namespace cp_algo::math::poly::impl {
100
113
// Q(-x) = P0(x^2) + xP1(x^2)
101
114
auto [q0, q1] = p.bisect (n);
102
115
103
- int N = fft::com_size ((n + 1 ) / 2 , (n + 1 ) / 2 );
116
+ int N = fft::com_size (size (q0. a ) , (n + 1 ) / 2 );
104
117
105
118
auto q0f = fft::dft<base>(q0.a , N);
106
119
auto q1f = fft::dft<base>(q1.a , N);
0 commit comments