44#include < iostream>
55#include < cassert>
66namespace cp_algo ::math {
7- inline constexpr auto inv2 (auto x) {
8- assert (x % 2 );
9- std::make_unsigned_t <decltype (x)> y = 1 ;
10- while (y * x != 1 ) {
11- y *= 2 - x * y;
12- }
13- return y;
14- }
157
168 template <typename modint, typename _Int>
179 struct modint_base {
@@ -23,97 +15,76 @@ namespace cp_algo::math {
2315 static Int mod () {
2416 return modint::mod ();
2517 }
26- static UInt imod () {
27- return modint::imod ();
18+ static Int remod () {
19+ return modint::remod ();
2820 }
29- static UInt2 pw128 () {
30- return modint::pw128 ();
31- }
32- static UInt m_reduce (UInt2 ab) {
33- if (mod () % 2 == 0 ) [[unlikely]] {
34- return UInt (ab % mod ());
35- } else {
36- UInt2 m = (UInt)ab * imod ();
37- return UInt ((ab + m * mod ()) >> bits);
38- }
39- }
40- static UInt m_reduce (Int2 ab) {
41- return m_reduce (UInt2 (ab + UInt2 (ab < 0 ) * mod () * mod ()));
42- }
43- static UInt m_transform (UInt a) {
44- if (mod () % 2 == 0 ) [[unlikely]] {
45- return a;
46- } else {
47- return m_reduce (a * pw128 ());
48- }
21+ static UInt2 modmod () {
22+ return UInt2 (mod ()) * mod ();
4923 }
5024 modint_base (): r(0 ) {}
51- modint_base (Int2 rr): r(UInt(rr % mod())) {
52- r = std::min (r, r + mod ());
53- r = m_transform (r);
25+ modint_base (Int2 rr) {
26+ to_modint ().setr (UInt ((rr + modmod ()) % mod ()));
5427 }
5528 modint inv () const {
5629 return bpow (to_modint (), mod () - 2 );
5730 }
5831 modint operator - () const {
5932 modint neg;
60- neg.r = std::min (-r, 2 * mod () - r);
33+ neg.r = std::min (-r, remod () - r);
6134 return neg;
6235 }
6336 modint& operator /= (const modint &t) {
6437 return to_modint () *= t.inv ();
6538 }
6639 modint& operator *= (const modint &t) {
67- r = m_reduce (( UInt2)r * t.r );
40+ r = UInt ( UInt2 (r) * t.r % mod () );
6841 return to_modint ();
6942 }
7043 modint& operator += (const modint &t) {
71- r += t.r ; r = std::min (r, r - 2 * mod ());
44+ r += t.r ; r = std::min (r, r - remod ());
7245 return to_modint ();
7346 }
7447 modint& operator -= (const modint &t) {
75- r -= t.r ; r = std::min (r, r + 2 * mod ());
48+ r -= t.r ; r = std::min (r, r + remod ());
7649 return to_modint ();
7750 }
7851 modint operator + (const modint &t) const {return modint (to_modint ()) += t;}
7952 modint operator - (const modint &t) const {return modint (to_modint ()) -= t;}
8053 modint operator * (const modint &t) const {return modint (to_modint ()) *= t;}
8154 modint operator / (const modint &t) const {return modint (to_modint ()) /= t;}
8255 // Why <=> doesn't work?..
83- auto operator == (const modint_base &t) const {return getr () == t.getr ();}
84- auto operator != (const modint_base &t) const {return getr () != t.getr ();}
85- auto operator <= (const modint_base &t) const {return getr () <= t.getr ();}
86- auto operator >= (const modint_base &t) const {return getr () >= t.getr ();}
87- auto operator < (const modint_base &t) const {return getr () < t.getr ();}
88- auto operator > (const modint_base &t) const {return getr () > t.getr ();}
56+ auto operator == (const modint &t) const {return to_modint (). getr () == t.getr ();}
57+ auto operator != (const modint &t) const {return to_modint (). getr () != t.getr ();}
58+ auto operator <= (const modint &t) const {return to_modint (). getr () <= t.getr ();}
59+ auto operator >= (const modint &t) const {return to_modint (). getr () >= t.getr ();}
60+ auto operator < (const modint &t) const {return to_modint (). getr () < t.getr ();}
61+ auto operator > (const modint &t) const {return to_modint (). getr () > t.getr ();}
8962 Int rem () const {
90- UInt R = getr ();
63+ UInt R = to_modint (). getr ();
9164 return 2 * R > (UInt)mod () ? R - mod () : R;
9265 }
66+ void setr (UInt rr) {
67+ r = rr;
68+ }
69+ UInt getr () const {
70+ return r;
71+ }
9372
94- // Only use if you really know what you're doing!
95- static UInt modmod () {return ( UInt) 8 * mod () * mod ( );};
73+ // Only use these if you really know what you're doing!
74+ static UInt modmod8 () {return UInt ( 8 * modmod () );}
9675 void add_unsafe (UInt t) {r += t;}
97- void pseudonormalize () {r = std::min (r, r - modmod ());}
76+ void pseudonormalize () {r = std::min (r, r - modmod8 ());}
9877 modint const & normalize () {
9978 if (r >= (UInt)mod ()) {
10079 r %= mod ();
10180 }
10281 return to_modint ();
10382 }
104- void setr (UInt rr) {r = m_transform (rr);}
105- UInt getr () const {
106- UInt res = m_reduce (UInt2 (r));
107- return std::min (res, res - mod ());
108- }
10983 void setr_direct (UInt rr) {r = rr;}
11084 UInt getr_direct () const {return r;}
111- Int rem_direct () const {
112- UInt R = std::min (r, r - mod ());
113- return 2 * R > (UInt)mod () ? R - mod () : R;
114- }
115- private:
85+ protected:
11686 UInt r;
87+ private:
11788 modint& to_modint () {return static_cast <modint&>(*this );}
11889 modint const & to_modint () const {return static_cast <modint const &>(*this );}
11990 };
@@ -135,18 +106,53 @@ namespace cp_algo::math {
135106 struct modint : modint_base<modint<m>, decltype (m)> {
136107 using Base = modint_base<modint<m>, decltype (m)>;
137108 using Base::Base;
138- static constexpr Base::UInt im = m % 2 ? inv2(-m) : 0 ;
139- static constexpr Base::UInt r2 = (typename Base::UInt2)(-1 ) % m + 1 ;
140109 static constexpr Base::Int mod () {return m;}
141- static constexpr Base::UInt imod () {return im ;}
142- static constexpr Base::UInt2 pw128 () {return r2 ;}
110+ static constexpr Base::UInt remod () {return m ;}
111+ auto getr () const {return Base::r ;}
143112 };
144113
114+ inline constexpr auto inv2 (auto x) {
115+ assert (x % 2 );
116+ std::make_unsigned_t <decltype (x)> y = 1 ;
117+ while (y * x != 1 ) {
118+ y *= 2 - x * y;
119+ }
120+ return y;
121+ }
122+
145123 template <typename Int = int64_t >
146124 struct dynamic_modint : modint_base<dynamic_modint<Int>, Int> {
147125 using Base = modint_base<dynamic_modint<Int>, Int>;
148126 using Base::Base;
127+
128+ static Base::UInt m_reduce (Base::UInt2 ab) {
129+ if (mod () % 2 == 0 ) [[unlikely]] {
130+ return typename Base::UInt (ab % mod ());
131+ } else {
132+ typename Base::UInt2 m = typename Base::UInt (ab) * imod ();
133+ return typename Base::UInt ((ab + m * mod ()) >> Base::bits);
134+ }
135+ }
136+ static Base::UInt m_transform (Base::UInt a) {
137+ if (mod () % 2 == 0 ) [[unlikely]] {
138+ return a;
139+ } else {
140+ return m_reduce (a * pw128 ());
141+ }
142+ }
143+ dynamic_modint& operator *= (const dynamic_modint &t) {
144+ Base::r = m_reduce (typename Base::UInt2 (Base::r) * t.r );
145+ return *this ;
146+ }
147+ void setr (Base::UInt rr) {
148+ Base::r = m_transform (rr);
149+ }
150+ Base::UInt getr () const {
151+ typename Base::UInt res = m_reduce (Base::r);
152+ return std::min (res, res - mod ());
153+ }
149154 static Int mod () {return m;}
155+ static Int remod () {return 2 * m;}
150156 static Base::UInt imod () {return im;}
151157 static Base::UInt2 pw128 () {return r2;}
152158 static void switch_mod (Int nm) {
0 commit comments