11#ifndef CP_ALGO_NUMBER_THEORY_MODINT_HPP
22#define CP_ALGO_NUMBER_THEORY_MODINT_HPP
3- #include " cp-algo /math/common.hpp"
3+ #include " .. /math/common.hpp"
44#include < iostream>
5+ #include < cassert>
56namespace cp_algo ::math {
7+ inline constexpr uint64_t inv64 (uint64_t x) {
8+ assert (x % 2 );
9+ uint64_t y = 1 ;
10+ while (y * x != 1 ) {
11+ y *= 2 - x * y;
12+ }
13+ return y;
14+ }
15+
616 template <typename modint>
717 struct modint_base {
818 static int64_t mod () {
919 return modint::mod ();
1020 }
21+ static uint64_t imod () {
22+ return modint::imod ();
23+ }
24+ static __uint128_t pw128 () {
25+ return modint::pw128 ();
26+ }
27+ static uint64_t m_reduce (__uint128_t ab) {
28+ if (mod () % 2 == 0 ) {
29+ return ab % mod ();
30+ } else {
31+ uint64_t m = ab * imod ();
32+ int64_t res = (ab + __uint128_t (m) * mod ()) >> 64 ;
33+ return res < mod () ? res : res - mod ();
34+ }
35+ }
36+ static uint64_t m_transform (uint64_t a) {
37+ if (mod () % 2 == 0 ) {
38+ return a;
39+ } else {
40+ return m_reduce (a * pw128 ());
41+ }
42+ }
1143 modint_base (): r(0 ) {}
1244 modint_base (int64_t rr): r(rr % mod()) {
1345 r = std::min (r, r + mod ());
46+ r = m_transform (r);
1447 }
1548 modint inv () const {
1649 return bpow (to_modint (), mod () - 2 );
1750 }
18- modint operator - () const {return std::min (-r, mod () - r);}
51+ modint operator - () const {
52+ modint neg;
53+ neg.r = std::min (-r, mod () - r);
54+ return neg;
55+ }
1956 modint& operator /= (const modint &t) {
2057 return to_modint () *= t.inv ();
2158 }
2259 modint& operator *= (const modint &t) {
23- if (mod () <= uint32_t (-1 )) {
24- r = r * t.r % mod ();
25- } else {
26- r = __int128 (r) * t.r % mod ();
27- }
60+ r = m_reduce (__uint128_t (r) * t.r );
2861 return to_modint ();
2962 }
3063 modint& operator += (const modint &t) {
@@ -40,7 +73,10 @@ namespace cp_algo::math {
4073 modint operator * (const modint &t) const {return modint (to_modint ()) *= t;}
4174 modint operator / (const modint &t) const {return modint (to_modint ()) /= t;}
4275 auto operator <=> (const modint_base &t) const = default ;
43- int64_t rem () const {return 2 * r > (uint64_t )mod () ? r - mod () : r;}
76+ int64_t rem () const {
77+ uint64_t R = getr ();
78+ return 2 * R > (uint64_t )mod () ? R - mod () : R;
79+ }
4480
4581 // Only use if you really know what you're doing!
4682 uint64_t modmod () const {return 8ULL * mod () * mod ();};
@@ -52,16 +88,21 @@ namespace cp_algo::math {
5288 }
5389 return to_modint ();
5490 }
55- uint64_t & setr () {return r;}
56- uint64_t getr () const {return r;}
91+ void setr (uint64_t rr) {r = m_transform (rr);}
92+ uint64_t getr () const {return m_reduce (r);}
93+ void setr_direct (uint64_t rr) {r = rr;}
94+ uint64_t getr_direct () const {return r;}
5795 private:
5896 uint64_t r;
5997 modint& to_modint () {return static_cast <modint&>(*this );}
6098 modint const & to_modint () const {return static_cast <modint const &>(*this );}
6199 };
62100 template <typename modint>
63101 std::istream& operator >> (std::istream &in, modint_base<modint> &x) {
64- return in >> x.setr ();
102+ uint64_t r;
103+ auto &res = in >> r;
104+ x.setr (r);
105+ return res;
65106 }
66107 template <typename modint>
67108 std::ostream& operator << (std::ostream &out, modint_base<modint> const & x) {
@@ -73,14 +114,24 @@ namespace cp_algo::math {
73114
74115 template <int64_t m>
75116 struct modint : modint_base<modint<m>> {
117+ static constexpr uint64_t im = m % 2 ? inv64(-m) : 0 ;
118+ static constexpr uint64_t r2 = __uint128_t (-1 ) % m + 1 ;
76119 static constexpr int64_t mod () {return m;}
120+ static constexpr uint64_t imod () {return im;}
121+ static constexpr __uint128_t pw128 () {return r2;}
77122 using Base = modint_base<modint<m>>;
78123 using Base::Base;
79124 };
80125
81126 struct dynamic_modint : modint_base<dynamic_modint> {
82127 static int64_t mod () {return m;}
83- static void switch_mod (int64_t nm) {m = nm;}
128+ static uint64_t imod () {return im;}
129+ static __uint128_t pw128 () {return r2;}
130+ static void switch_mod (int64_t nm) {
131+ m = nm;
132+ im = m % 2 ? inv64 (-m) : 0 ;
133+ r2 = __uint128_t (-1 ) % m + 1 ;
134+ }
84135 using Base = modint_base<dynamic_modint>;
85136 using Base::Base;
86137
@@ -95,7 +146,10 @@ namespace cp_algo::math {
95146 }
96147 private:
97148 static int64_t m;
149+ static uint64_t im, r1, r2;
98150 };
99- int64_t dynamic_modint::m = 0 ;
151+ int64_t dynamic_modint::m = 1 ;
152+ uint64_t dynamic_modint::im = -1 ;
153+ uint64_t dynamic_modint::r2 = 0 ;
100154}
101155#endif // CP_ALGO_MATH_MODINT_HPP
0 commit comments