Skip to content

Commit a25f6fb

Browse files
committed
Refactor modint, try to introduce dynamic_modint
1 parent dc61819 commit a25f6fb

25 files changed

+191
-119
lines changed

cp-algo/algebra/all

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#ifndef CP_ALGO_ALGEBRA_ALL
22
#define CP_ALGO_ALGEBRA_ALL
3-
#include "modular.hpp"
3+
#include "modint.hpp"
44
#include "affine.hpp"
55
#include "common.hpp"
66
#include "poly.hpp"

cp-algo/algebra/fft.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#ifndef CP_ALGO_ALGEBRA_FFT_HPP
22
#define CP_ALGO_ALGEBRA_FFT_HPP
33
#include "common.hpp"
4-
#include "modular.hpp"
4+
#include "modint.hpp"
55
#include <algorithm>
66
#include <complex>
77
#include <cassert>
@@ -75,7 +75,7 @@ namespace cp_algo::algebra::fft {
7575
static constexpr int split = 1 << 15;
7676
std::vector<point> A;
7777

78-
dft(std::vector<modular<m>> const& a, size_t n): A(n) {
78+
dft(std::vector<modint<m>> const& a, size_t n): A(n) {
7979
for(size_t i = 0; i < std::min(n, a.size()); i++) {
8080
A[i] = point(
8181
a[i].rem() % split,
@@ -91,7 +91,7 @@ namespace cp_algo::algebra::fft {
9191
assert(A.size() == B.A.size());
9292
size_t n = A.size();
9393
if(!n) {
94-
return std::vector<modular<m>>();
94+
return std::vector<modint<m>>();
9595
}
9696
std::vector<point> C(n), D(n);
9797
for(size_t i = 0; i < n; i++) {
@@ -103,11 +103,11 @@ namespace cp_algo::algebra::fft {
103103
reverse(begin(C) + 1, end(C));
104104
reverse(begin(D) + 1, end(D));
105105
int t = 2 * n;
106-
std::vector<modular<m>> res(n);
106+
std::vector<modint<m>> res(n);
107107
for(size_t i = 0; i < n; i++) {
108-
modular<m> A0 = llround(C[i].real() / t);
109-
modular<m> A1 = llround(C[i].imag() / t + D[i].imag() / t);
110-
modular<m> A2 = llround(D[i].real() / t);
108+
modint<m> A0 = llround(C[i].real() / t);
109+
modint<m> A1 = llround(C[i].imag() / t + D[i].imag() / t);
110+
modint<m> A2 = llround(D[i].real() / t);
111111
res[i] = A0 + A1 * split - A2 * split * split;
112112
}
113113
return res;
@@ -129,7 +129,7 @@ namespace cp_algo::algebra::fft {
129129
}
130130

131131
template<int m>
132-
void mul(std::vector<modular<m>> &a, std::vector<modular<m>> b) {
132+
void mul(std::vector<modint<m>> &a, std::vector<modint<m>> b) {
133133
if(std::min(a.size(), b.size()) < magic) {
134134
mul_slow(a, b);
135135
return;

cp-algo/algebra/modint.hpp

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#ifndef CP_ALGO_ALGEBRA_MODINT_HPP
2+
#define CP_ALGO_ALGEBRA_MODINT_HPP
3+
#include "../random/rng.hpp"
4+
#include "affine.hpp"
5+
#include "common.hpp"
6+
#include <algorithm>
7+
#include <iostream>
8+
#include <optional>
9+
namespace cp_algo::algebra {
10+
template<typename modint>
11+
struct modint_base {
12+
// Would make it virtual, but it affects avx...
13+
int mod() const {
14+
if constexpr(modint::static_mod) {
15+
return modint::static_mod;
16+
} else {
17+
return static_cast<modint*>(this)->mod();
18+
}
19+
}
20+
modint_base(): r(0) {}
21+
modint_base(int64_t rr): r(rr % mod()) {
22+
r = std::min(r, r + mod());
23+
}
24+
modint inv() const {
25+
return bpow(static_cast<modint const&>(*this), mod() - 2);
26+
}
27+
modint operator - () const {return std::min(-r, mod() - r);}
28+
modint& operator /= (const modint &t) {return *this *= t.inv();}
29+
modint& operator *= (const modint &t) {
30+
r *= t.r; r %= mod();
31+
return static_cast<modint&>(*this);
32+
}
33+
modint& operator += (const modint &t) {
34+
r += t.r; r = std::min(r, r - mod());
35+
return static_cast<modint&>(*this);
36+
}
37+
modint& operator -= (const modint &t) {
38+
r -= t.r; r = std::min(r, r + mod());
39+
return static_cast<modint&>(*this);
40+
}
41+
modint operator + (const modint &t) const {
42+
return modint(static_cast<modint const&>(*this)) += t;
43+
}
44+
modint operator - (const modint &t) const {
45+
return modint(static_cast<modint const&>(*this)) -= t;
46+
}
47+
modint operator * (const modint &t) const {
48+
return modint(static_cast<modint const&>(*this)) *= t;
49+
}
50+
modint operator / (const modint &t) const {
51+
return modint(static_cast<modint const&>(*this)) /= t;
52+
}
53+
auto operator <=> (const modint_base &t) const = default;
54+
explicit operator int() const {return r;}
55+
int64_t rem() const {return 2 * r > (uint64_t)mod() ? r - mod() : r;}
56+
57+
// Only use if you really know what you're doing!
58+
uint64_t modmod() const {return 8LL * mod() * mod();};
59+
void add_unsafe(uint64_t t) {r += t;}
60+
void pseudonormalize() {r = std::min(r, r - modmod());}
61+
modint const& normalize() {
62+
if(r >= (uint64_t)mod()) {
63+
r %= mod();
64+
}
65+
return static_cast<modint&>(*this);
66+
}
67+
uint64_t& setr() {return r;}
68+
uint64_t getr() const {return r;}
69+
private:
70+
uint64_t r;
71+
};
72+
template<typename modint>
73+
std::istream& operator >> (std::istream &in, modint_base<modint> &x) {
74+
75+
return in >> x.setr();
76+
}
77+
template<typename modint>
78+
std::ostream& operator << (std::ostream &out, modint_base<modint> const& x) {
79+
return out << x.getr();
80+
}
81+
82+
template<int m>
83+
struct modint: modint_base<modint<m>> {
84+
constexpr static int static_mod = m;
85+
using Base = modint_base<modint<m>>;
86+
using Base::Base;
87+
constexpr static uint64_t static_modmod = 8LL*m*m;
88+
};
89+
90+
struct dynamic_modint: modint_base<dynamic_modint> {
91+
constexpr static int static_mod = 0;
92+
using Base = modint_base<dynamic_modint>;
93+
int mod() const {return m;}
94+
dynamic_modint(dynamic_modint const& t): Base(t), m(t.m) {}
95+
dynamic_modint(int m, int64_t r): m(m) {
96+
setr() = r % m;
97+
setr() = std::min(setr(), setr() + mod());
98+
}
99+
static auto GF(int mod) {
100+
return [mod](int64_t r) {
101+
return dynamic_modint(mod, r);
102+
};
103+
}
104+
private:
105+
int m;
106+
};
107+
}
108+
#endif // CP_ALGO_ALGEBRA_MODINT_HPP

cp-algo/algebra/modular.hpp

Lines changed: 0 additions & 69 deletions
This file was deleted.

cp-algo/algebra/number_theory.hpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#ifndef CP_ALGO_ALGEBRA_NUMBER_THEORY_HPP
2+
#define CP_ALGO_ALGEBRA_NUMBER_THEORY_HPP
3+
#include "../random/rng.hpp"
4+
#include "affine.hpp"
5+
#include "modint.hpp"
6+
#include <algorithm>
7+
#include <iostream>
8+
#include <optional>
9+
namespace cp_algo::algebra {
10+
// https://en.wikipedia.org/wiki/Berlekamp-Rabin_algorithm
11+
template<typename base>
12+
requires(std::is_base_of_v<modint_base<base>, base>)
13+
std::optional<base> sqrt(base b) {
14+
if(b == base(0)) {
15+
return base(0);
16+
} else if(bpow(b, (b.mod() - 1) / 2) != base(1)) {
17+
return std::nullopt;
18+
} else {
19+
while(true) {
20+
base z = random::rng();
21+
if(z * z == b) {
22+
return z;
23+
}
24+
lin<base> x(1, z, b); // x + z (mod x^2 - b)
25+
x = bpow(x, (b.mod() - 1) / 2, lin<base>(0, 1, b));
26+
if(x.a != base(0)) {
27+
return x.a.inv();
28+
}
29+
}
30+
}
31+
}
32+
}
33+
#endif // CP_ALGO_ALGEBRA_NUMBER_THEORY_HPP

cp-algo/algebra/poly.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "poly/impl/euclid.hpp"
44
#include "poly/impl/base.hpp"
55
#include "poly/impl/div.hpp"
6-
#include "common.hpp"
6+
#include "number_theory.hpp"
77
#include "fft.hpp"
88
#include <functional>
99
#include <algorithm>
@@ -277,7 +277,7 @@ namespace cp_algo::algebra {
277277
auto ans = div_xk(i).sqrt(n - i / 2);
278278
return ans ? ans->mul_xk(i / 2) : ans;
279279
}
280-
auto st = (*this)[0].sqrt();
280+
auto st = algebra::sqrt((*this)[0]);
281281
if(st) {
282282
poly_t ans = *st;
283283
size_t a = 1;

cp-algo/linalg/vector.hpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#ifndef CP_ALGO_LINALG_VECTOR_HPP
22
#define CP_ALGO_LINALG_VECTOR_HPP
3-
#include "../algebra/modular.hpp"
3+
#include "../algebra/modint.hpp"
44
#include "../random/rng.hpp"
55
#include <functional>
66
#include <algorithm>
@@ -15,10 +15,10 @@ namespace cp_algo::linalg {
1515

1616
valarray_base(base const& t): Base(t, 1) {}
1717

18-
auto begin() {return std::begin(*static_cast<Base*>(this));}
19-
auto end() {return std::end(*static_cast<Base*>(this));}
20-
auto begin() const {return std::begin(*static_cast<Base const*>(this));}
21-
auto end() const {return std::end(*static_cast<Base const*>(this));}
18+
auto begin() {return std::begin(static_cast<Base&>(*this));}
19+
auto end() {return std::end(static_cast<Base&>(*this));}
20+
auto begin() const {return std::begin(static_cast<Base const&>(*this));}
21+
auto end() const {return std::end(static_cast<Base const&>(*this));}
2222

2323
bool operator == (vec const& t) const {return std::ranges::equal(*this, t);}
2424
bool operator != (vec const& t) const {return !(*this == t);}
@@ -34,13 +34,13 @@ namespace cp_algo::linalg {
3434

3535
template<class vec, typename base>
3636
vec operator+(valarray_base<vec, base> const& a, valarray_base<vec, base> const& b) {
37-
return *static_cast<std::valarray<base> const*>(&a)
38-
+ *static_cast<std::valarray<base> const*>(&b);
37+
return static_cast<std::valarray<base> const&>(a)
38+
+ static_cast<std::valarray<base> const&>(b);
3939
}
4040
template<class vec, typename base>
4141
vec operator-(valarray_base<vec, base> const& a, valarray_base<vec, base> const& b) {
42-
return *static_cast<std::valarray<base> const*>(&a)
43-
- *static_cast<std::valarray<base> const*>(&b);
42+
return static_cast<std::valarray<base> const&>(a)
43+
- static_cast<std::valarray<base> const&>(b);
4444
}
4545

4646
template<class vec, typename base>
@@ -60,10 +60,10 @@ namespace cp_algo::linalg {
6060
(*this)[i] += scale * b[i];
6161
}
6262
}
63-
virtual vec& normalize() {
64-
return *static_cast<vec*>(this);
63+
virtual vec const& normalize() {
64+
return static_cast<vec&>(*this);
6565
}
66-
virtual base& normalize(size_t i) {
66+
virtual base normalize(size_t i) {
6767
return (*this)[i];
6868
}
6969
void read() {
@@ -120,15 +120,15 @@ namespace cp_algo::linalg {
120120
};
121121

122122
template<int mod>
123-
struct vec<algebra::modular<mod>>:
124-
vec_base<vec<algebra::modular<mod>>, algebra::modular<mod>> {
125-
using base = algebra::modular<mod>;
123+
struct vec<algebra::modint<mod>>:
124+
vec_base<vec<algebra::modint<mod>>, algebra::modint<mod>> {
125+
using base = algebra::modint<mod>;
126126
using Base = vec_base<vec<base>, base>;
127127
using Base::Base;
128128

129129
void add_scaled(vec const& b, base scale, size_t i = 0) override {
130130
for(; i < size(*this); i++) {
131-
(*this)[i].add_unsafe(scale.r * b[i].r);
131+
(*this)[i].add_unsafe(scale.getr() * b[i].getr());
132132
}
133133
if(++counter == 8) {
134134
for(auto &it: *this) {
@@ -137,13 +137,13 @@ namespace cp_algo::linalg {
137137
counter = 0;
138138
}
139139
}
140-
vec& normalize() override {
140+
vec const& normalize() override {
141141
for(auto &it: *this) {
142142
it.normalize();
143143
}
144144
return *this;
145145
}
146-
base& normalize(size_t i) override {
146+
base normalize(size_t i) override {
147147
return (*this)[i].normalize();
148148
}
149149
private:

verify/algebra/matrix/characteristic.test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using namespace cp_algo::algebra;
1111
using namespace cp_algo::linalg;
1212

1313
const int mod = 998244353;
14-
using base = modular<mod>;
14+
using base = modint<mod>;
1515
using polyn = poly_t<base>;
1616

1717
void solve() {

verify/algebra/matrix/det.test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ const int mod = 998244353;
1414
void solve() {
1515
int n;
1616
cin >> n;
17-
matrix<modular<mod>> a(n, n);
17+
matrix<modint<mod>> a(n, n);
1818
a.read();
1919
cout << a.det() << "\n";
2020
}

0 commit comments

Comments
 (0)