Skip to content

Commit fc70117

Browse files
committed
Parameterize storage type for modint
1 parent 62350f5 commit fc70117

24 files changed

+86
-78
lines changed

cp-algo/linalg/vector.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ namespace cp_algo::linalg {
127127
using Base::Base;
128128

129129
void add_scaled(vec const& b, base scale, size_t i = 0) override {
130+
static_assert(base::bits >= 64, "Only wide modint types for linalg");
130131
uint64_t scaler = scale.getr();
131132
if(scale != base(0)) {
132133
for(; i < size(*this); i++) {

cp-algo/number_theory/discrete_log.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace cp_algo::math {
1313
return res ? std::optional(*res + 1) : res;
1414
}
1515
// a * b^x is periodic here
16-
using base = dynamic_modint;
16+
using base = dynamic_modint<>;
1717
return base::with_mod(m, [&]() -> std::optional<uint64_t> {
1818
size_t sqrtmod = std::max<size_t>(1, std::sqrt(m) / 2);
1919
std::unordered_map<int64_t, int> small;

cp-algo/number_theory/euler.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace cp_algo::math {
2525
return ans;
2626
}
2727
int64_t primitive_root(int64_t p) {
28-
using base = dynamic_modint;
28+
using base = dynamic_modint<>;
2929
return base::with_mod(p, [p](){
3030
base t = 1;
3131
while(period(t) != p - 1) {

cp-algo/number_theory/factorize.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace cp_algo::math {
1111
} else if(is_prime(m)) {
1212
res.push_back(m);
1313
} else if(m > 1) {
14-
using base = dynamic_modint;
14+
using base = dynamic_modint<>;
1515
base::with_mod(m, [&]() {
1616
base t = random::rng();
1717
auto f = [&](auto x) {

cp-algo/number_theory/modint.hpp

Lines changed: 62 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,43 +4,47 @@
44
#include <iostream>
55
#include <cassert>
66
namespace cp_algo::math {
7-
inline constexpr uint64_t inv64(uint64_t x) {
7+
inline constexpr auto inv2(auto x) {
88
assert(x % 2);
9-
uint64_t y = 1;
9+
std::make_unsigned_t<decltype(x)> y = 1;
1010
while(y * x != 1) {
1111
y *= 2 - x * y;
1212
}
1313
return y;
1414
}
1515

16-
template<typename modint>
16+
template<typename modint, typename _Int>
1717
struct modint_base {
18-
static int64_t mod() {
18+
using Int = _Int;
19+
using Uint = std::make_unsigned_t<Int>;
20+
static constexpr size_t bits = sizeof(Int) * 8;
21+
using Int2 = std::conditional_t<bits <= 32, uint64_t, __uint128_t>;
22+
static Int mod() {
1923
return modint::mod();
2024
}
21-
static uint64_t imod() {
25+
static Uint imod() {
2226
return modint::imod();
2327
}
24-
static __uint128_t pw128() {
28+
static Int2 pw128() {
2529
return modint::pw128();
2630
}
27-
static uint64_t m_reduce(__uint128_t ab) {
31+
static Uint m_reduce(Int2 ab) {
2832
if(mod() % 2 == 0) [[unlikely]] {
2933
return ab % mod();
3034
} else {
31-
uint64_t m = ab * imod();
32-
return (ab + __uint128_t(m) * mod()) >> 64;
35+
Uint m = ab * imod();
36+
return (ab + (Int2)m * mod()) >> bits;
3337
}
3438
}
35-
static uint64_t m_transform(uint64_t a) {
39+
static Uint m_transform(Uint a) {
3640
if(mod() % 2 == 0) [[unlikely]] {
3741
return a;
3842
} else {
3943
return m_reduce(a * pw128());
4044
}
4145
}
4246
modint_base(): r(0) {}
43-
modint_base(int64_t rr): r(rr % mod()) {
47+
modint_base(Int rr): r(rr % mod()) {
4448
r = std::min(r, r + mod());
4549
r = m_transform(r);
4650
}
@@ -56,7 +60,7 @@ namespace cp_algo::math {
5660
return to_modint() *= t.inv();
5761
}
5862
modint& operator *= (const modint &t) {
59-
r = m_reduce(__uint128_t(r) * t.r);
63+
r = m_reduce((Int2)r * t.r);
6064
return to_modint();
6165
}
6266
modint& operator += (const modint &t) {
@@ -78,86 +82,89 @@ namespace cp_algo::math {
7882
auto operator >= (const modint_base &t) const {return getr() >= t.getr();}
7983
auto operator < (const modint_base &t) const {return getr() < t.getr();}
8084
auto operator > (const modint_base &t) const {return getr() > t.getr();}
81-
int64_t rem() const {
82-
uint64_t R = getr();
83-
return 2 * R > (uint64_t)mod() ? R - mod() : R;
85+
Int rem() const {
86+
Uint R = getr();
87+
return 2 * R > (Uint)mod() ? R - mod() : R;
8488
}
8589

8690
// Only use if you really know what you're doing!
87-
uint64_t modmod() const {return 8ULL * mod() * mod();};
88-
void add_unsafe(uint64_t t) {r += t;}
91+
Uint modmod() const {return (Uint)8 * mod() * mod();};
92+
void add_unsafe(Uint t) {r += t;}
8993
void pseudonormalize() {r = std::min(r, r - modmod());}
9094
modint const& normalize() {
91-
if(r >= (uint64_t)mod()) {
95+
if(r >= (Uint)mod()) {
9296
r %= mod();
9397
}
9498
return to_modint();
9599
}
96-
void setr(uint64_t rr) {r = m_transform(rr);}
97-
uint64_t getr() const {
98-
uint64_t res = m_reduce(r);
100+
void setr(Uint rr) {r = m_transform(rr);}
101+
Uint getr() const {
102+
Uint res = m_reduce(r);
99103
return std::min(res, res - mod());
100104
}
101-
void setr_direct(uint64_t rr) {r = rr;}
102-
uint64_t getr_direct() const {return r;}
105+
void setr_direct(Uint rr) {r = rr;}
106+
Uint getr_direct() const {return r;}
103107
private:
104-
uint64_t r;
108+
Uint r;
105109
modint& to_modint() {return static_cast<modint&>(*this);}
106110
modint const& to_modint() const {return static_cast<modint const&>(*this);}
107111
};
108112
template<typename modint>
109-
std::istream& operator >> (std::istream &in, modint_base<modint> &x) {
110-
uint64_t r;
113+
concept modint_type = std::is_base_of_v<modint_base<modint, typename modint::Int>, modint>;
114+
template<modint_type modint>
115+
std::istream& operator >> (std::istream &in, modint &x) {
116+
typename modint::Uint r;
111117
auto &res = in >> r;
112118
x.setr(r);
113119
return res;
114120
}
115-
template<typename modint>
116-
std::ostream& operator << (std::ostream &out, modint_base<modint> const& x) {
121+
template<modint_type modint>
122+
std::ostream& operator << (std::ostream &out, modint const& x) {
117123
return out << x.getr();
118124
}
119125

120-
template<typename modint>
121-
concept modint_type = std::is_base_of_v<modint_base<modint>, modint>;
122-
123-
template<int64_t m>
124-
struct modint: modint_base<modint<m>> {
125-
static constexpr uint64_t im = m % 2 ? inv64(-m) : 0;
126-
static constexpr uint64_t r2 = __uint128_t(-1) % m + 1;
127-
static constexpr int64_t mod() {return m;}
128-
static constexpr uint64_t imod() {return im;}
129-
static constexpr __uint128_t pw128() {return r2;}
130-
using Base = modint_base<modint<m>>;
126+
template<auto m>
127+
struct modint: modint_base<modint<m>, decltype(m)> {
128+
using Base = modint_base<modint<m>, decltype(m)>;
131129
using Base::Base;
130+
static constexpr Base::Uint im = m % 2 ? inv2(-m) : 0;
131+
static constexpr Base::Uint r2 = (typename Base::Int2)(-1) % m + 1;
132+
static constexpr Base::Int mod() {return m;}
133+
static constexpr Base::Uint imod() {return im;}
134+
static constexpr Base::Int2 pw128() {return r2;}
132135
};
133136

134-
struct dynamic_modint: modint_base<dynamic_modint> {
135-
static int64_t mod() {return m;}
136-
static uint64_t imod() {return im;}
137-
static __uint128_t pw128() {return r2;}
138-
static void switch_mod(int64_t nm) {
137+
template<typename Int = int64_t>
138+
struct dynamic_modint: modint_base<dynamic_modint<Int>, Int> {
139+
using Base = modint_base<dynamic_modint<Int>, Int>;
140+
using Base::Base;
141+
static Int mod() {return m;}
142+
static Base::Uint imod() {return im;}
143+
static Base::Int2 pw128() {return r2;}
144+
static void switch_mod(Int nm) {
139145
m = nm;
140-
im = m % 2 ? inv64(-m) : 0;
141-
r2 = __uint128_t(-1) % m + 1;
146+
im = m % 2 ? inv2(-m) : 0;
147+
r2 = (typename Base::Int2)(-1) % m + 1;
142148
}
143-
using Base = modint_base<dynamic_modint>;
144-
using Base::Base;
145149

146150
// Wrapper for temp switching
147-
auto static with_mod(int64_t tmp, auto callback) {
151+
auto static with_mod(Int tmp, auto callback) {
148152
struct scoped {
149-
int64_t prev = mod();
153+
Int prev = mod();
150154
~scoped() {switch_mod(prev);}
151155
} _;
152156
switch_mod(tmp);
153157
return callback();
154158
}
155159
private:
156-
static int64_t m;
157-
static uint64_t im, r1, r2;
160+
static Int m;
161+
static Base::Uint im, r1, r2;
158162
};
159-
int64_t dynamic_modint::m = 1;
160-
uint64_t dynamic_modint::im = -1;
161-
uint64_t dynamic_modint::r2 = 0;
163+
template<typename Int>
164+
Int dynamic_modint<Int>::m = 1;
165+
template<typename Int>
166+
dynamic_modint<Int>::Base::Uint dynamic_modint<Int>::im = -1;
167+
template<typename Int>
168+
dynamic_modint<Int>::Base::Uint dynamic_modint<Int>::r2 = 0;
162169
}
163170
#endif // CP_ALGO_MATH_MODINT_HPP

cp-algo/number_theory/primality.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace cp_algo::math {
1212
// m - 1 = 2^s * d
1313
int s = std::countr_zero(m - 1);
1414
auto d = (m - 1) >> s;
15-
using base = dynamic_modint;
15+
using base = dynamic_modint<>;
1616
auto test = [&](base x) {
1717
x = bpow(x, d);
1818
if(std::abs(x.rem()) <= 1) {

cp-algo/number_theory/two_squares.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace cp_algo::math {
1313
return gaussint(1, 1);
1414
}
1515
assert(p % 4 == 1);
16-
using base = dynamic_modint;
16+
using base = dynamic_modint<>;
1717
return base::with_mod(p, [&](){
1818
base g = primitive_root(p);
1919
int64_t i = bpow(g, (p - 1) / 4).getr();

verify/linalg/adj.test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include "cp-algo/linalg/matrix.hpp"
55
#include <bits/stdc++.h>
66

7-
const int mod = 998244353;
7+
const int64_t mod = 998244353;
88

99
using namespace std;
1010
using cp_algo::math::modint;

verify/linalg/characteristic.test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using namespace std;
1010
using namespace cp_algo::math;
1111
using namespace cp_algo::linalg;
1212

13-
const int mod = 998244353;
13+
const int64_t mod = 998244353;
1414
using base = modint<mod>;
1515
using polyn = poly_t<base>;
1616

verify/linalg/det.test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using namespace std;
99
using namespace cp_algo::linalg;
1010
using namespace cp_algo::math;
1111

12-
const int mod = 998244353;
12+
const int64_t mod = 998244353;
1313

1414
void solve() {
1515
int n;

0 commit comments

Comments
 (0)