4
4
#include < iostream>
5
5
#include < cassert>
6
6
namespace cp_algo ::math {
7
- inline constexpr uint64_t inv64 ( uint64_t x) {
7
+ inline constexpr auto inv2 ( auto x) {
8
8
assert (x % 2 );
9
- uint64_t y = 1 ;
9
+ std:: make_unsigned_t < decltype (x)> y = 1 ;
10
10
while (y * x != 1 ) {
11
11
y *= 2 - x * y;
12
12
}
13
13
return y;
14
14
}
15
15
16
- template <typename modint>
16
+ template <typename modint, typename _Int >
17
17
struct modint_base {
18
- static int64_t mod () {
18
+ using Int = _Int;
19
+ using Uint = std::make_unsigned_t <Int>;
20
+ static constexpr size_t bits = sizeof (Int) * 8 ;
21
+ using Int2 = std::conditional_t <bits <= 32 , uint64_t , __uint128_t >;
22
+ static Int mod () {
19
23
return modint::mod ();
20
24
}
21
- static uint64_t imod () {
25
+ static Uint imod () {
22
26
return modint::imod ();
23
27
}
24
- static __uint128_t pw128 () {
28
+ static Int2 pw128 () {
25
29
return modint::pw128 ();
26
30
}
27
- static uint64_t m_reduce (__uint128_t ab) {
31
+ static Uint m_reduce (Int2 ab) {
28
32
if (mod () % 2 == 0 ) [[unlikely]] {
29
33
return ab % mod ();
30
34
} else {
31
- uint64_t m = ab * imod ();
32
- return (ab + __uint128_t (m) * mod ()) >> 64 ;
35
+ Uint m = ab * imod ();
36
+ return (ab + (Int2)m * mod ()) >> bits ;
33
37
}
34
38
}
35
- static uint64_t m_transform (uint64_t a) {
39
+ static Uint m_transform (Uint a) {
36
40
if (mod () % 2 == 0 ) [[unlikely]] {
37
41
return a;
38
42
} else {
39
43
return m_reduce (a * pw128 ());
40
44
}
41
45
}
42
46
modint_base (): r(0 ) {}
43
- modint_base (int64_t rr): r(rr % mod()) {
47
+ modint_base (Int rr): r(rr % mod()) {
44
48
r = std::min (r, r + mod ());
45
49
r = m_transform (r);
46
50
}
@@ -56,7 +60,7 @@ namespace cp_algo::math {
56
60
return to_modint () *= t.inv ();
57
61
}
58
62
modint& operator *= (const modint &t) {
59
- r = m_reduce (__uint128_t (r) * t.r );
63
+ r = m_reduce ((Int2)r * t.r );
60
64
return to_modint ();
61
65
}
62
66
modint& operator += (const modint &t) {
@@ -78,86 +82,89 @@ namespace cp_algo::math {
78
82
auto operator >= (const modint_base &t) const {return getr () >= t.getr ();}
79
83
auto operator < (const modint_base &t) const {return getr () < t.getr ();}
80
84
auto operator > (const modint_base &t) const {return getr () > t.getr ();}
81
- int64_t rem () const {
82
- uint64_t R = getr ();
83
- return 2 * R > (uint64_t )mod () ? R - mod () : R;
85
+ Int rem () const {
86
+ Uint R = getr ();
87
+ return 2 * R > (Uint )mod () ? R - mod () : R;
84
88
}
85
89
86
90
// Only use if you really know what you're doing!
87
- uint64_t modmod () const {return 8ULL * mod () * mod ();};
88
- void add_unsafe (uint64_t t) {r += t;}
91
+ Uint modmod () const {return (Uint) 8 * mod () * mod ();};
92
+ void add_unsafe (Uint t) {r += t;}
89
93
void pseudonormalize () {r = std::min (r, r - modmod ());}
90
94
modint const & normalize () {
91
- if (r >= (uint64_t )mod ()) {
95
+ if (r >= (Uint )mod ()) {
92
96
r %= mod ();
93
97
}
94
98
return to_modint ();
95
99
}
96
- void setr (uint64_t rr) {r = m_transform (rr);}
97
- uint64_t getr () const {
98
- uint64_t res = m_reduce (r);
100
+ void setr (Uint rr) {r = m_transform (rr);}
101
+ Uint getr () const {
102
+ Uint res = m_reduce (r);
99
103
return std::min (res, res - mod ());
100
104
}
101
- void setr_direct (uint64_t rr) {r = rr;}
102
- uint64_t getr_direct () const {return r;}
105
+ void setr_direct (Uint rr) {r = rr;}
106
+ Uint getr_direct () const {return r;}
103
107
private:
104
- uint64_t r;
108
+ Uint r;
105
109
modint& to_modint () {return static_cast <modint&>(*this );}
106
110
modint const & to_modint () const {return static_cast <modint const &>(*this );}
107
111
};
108
112
template <typename modint>
109
- std::istream& operator >> (std::istream &in, modint_base<modint> &x) {
110
- uint64_t r;
113
+ concept modint_type = std::is_base_of_v<modint_base<modint, typename modint::Int>, modint>;
114
+ template <modint_type modint>
115
+ std::istream& operator >> (std::istream &in, modint &x) {
116
+ typename modint::Uint r;
111
117
auto &res = in >> r;
112
118
x.setr (r);
113
119
return res;
114
120
}
115
- template <typename modint>
116
- std::ostream& operator << (std::ostream &out, modint_base< modint> const & x) {
121
+ template <modint_type modint>
122
+ std::ostream& operator << (std::ostream &out, modint const & x) {
117
123
return out << x.getr ();
118
124
}
119
125
120
- template <typename modint>
121
- concept modint_type = std::is_base_of_v<modint_base<modint>, modint>;
122
-
123
- template <int64_t m>
124
- struct modint : modint_base<modint<m>> {
125
- static constexpr uint64_t im = m % 2 ? inv64(-m) : 0 ;
126
- static constexpr uint64_t r2 = __uint128_t (-1 ) % m + 1 ;
127
- static constexpr int64_t mod () {return m;}
128
- static constexpr uint64_t imod () {return im;}
129
- static constexpr __uint128_t pw128 () {return r2;}
130
- using Base = modint_base<modint<m>>;
126
+ template <auto m>
127
+ struct modint : modint_base<modint<m>, decltype (m)> {
128
+ using Base = modint_base<modint<m>, decltype (m)>;
131
129
using Base::Base;
130
+ static constexpr Base::Uint im = m % 2 ? inv2(-m) : 0 ;
131
+ static constexpr Base::Uint r2 = (typename Base::Int2)(-1 ) % m + 1 ;
132
+ static constexpr Base::Int mod () {return m;}
133
+ static constexpr Base::Uint imod () {return im;}
134
+ static constexpr Base::Int2 pw128 () {return r2;}
132
135
};
133
136
134
- struct dynamic_modint : modint_base<dynamic_modint> {
135
- static int64_t mod () {return m;}
136
- static uint64_t imod () {return im;}
137
- static __uint128_t pw128 () {return r2;}
138
- static void switch_mod (int64_t nm) {
137
+ template <typename Int = int64_t >
138
+ struct dynamic_modint : modint_base<dynamic_modint<Int>, Int> {
139
+ using Base = modint_base<dynamic_modint<Int>, Int>;
140
+ using Base::Base;
141
+ static Int mod () {return m;}
142
+ static Base::Uint imod () {return im;}
143
+ static Base::Int2 pw128 () {return r2;}
144
+ static void switch_mod (Int nm) {
139
145
m = nm;
140
- im = m % 2 ? inv64 (-m) : 0 ;
141
- r2 = __uint128_t (-1 ) % m + 1 ;
146
+ im = m % 2 ? inv2 (-m) : 0 ;
147
+ r2 = ( typename Base::Int2) (-1 ) % m + 1 ;
142
148
}
143
- using Base = modint_base<dynamic_modint>;
144
- using Base::Base;
145
149
146
150
// Wrapper for temp switching
147
- auto static with_mod (int64_t tmp, auto callback) {
151
+ auto static with_mod (Int tmp, auto callback) {
148
152
struct scoped {
149
- int64_t prev = mod();
153
+ Int prev = mod();
150
154
~scoped () {switch_mod (prev);}
151
155
} _;
152
156
switch_mod (tmp);
153
157
return callback ();
154
158
}
155
159
private:
156
- static int64_t m;
157
- static uint64_t im, r1, r2;
160
+ static Int m;
161
+ static Base::Uint im, r1, r2;
158
162
};
159
- int64_t dynamic_modint::m = 1 ;
160
- uint64_t dynamic_modint::im = -1 ;
161
- uint64_t dynamic_modint::r2 = 0 ;
163
+ template <typename Int>
164
+ Int dynamic_modint<Int>::m = 1 ;
165
+ template <typename Int>
166
+ dynamic_modint<Int>::Base::Uint dynamic_modint<Int>::im = -1 ;
167
+ template <typename Int>
168
+ dynamic_modint<Int>::Base::Uint dynamic_modint<Int>::r2 = 0 ;
162
169
}
163
170
#endif // CP_ALGO_MATH_MODINT_HPP
0 commit comments