Skip to content

Commit b671efe

Browse files
committed
Use montgomery in modint
1 parent ba1ce4b commit b671efe

File tree

10 files changed

+80
-28
lines changed

10 files changed

+80
-28
lines changed

cp-algo/linalg/vector.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,10 @@ namespace cp_algo::linalg {
127127
using Base::Base;
128128

129129
void add_scaled(vec const& b, base scale, size_t i = 0) override {
130+
uint64_t scaler = scale.getr();
130131
if(scale != base(0)) {
131132
for(; i < size(*this); i++) {
132-
(*this)[i].add_unsafe(scale.getr() * b[i].getr());
133+
(*this)[i].add_unsafe(scaler * b[i].getr_direct());
133134
}
134135
if(++counter == 8) {
135136
for(auto &it: *this) {

cp-algo/math/fft.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#ifndef CP_ALGO_MATH_FFT_HPP
22
#define CP_ALGO_MATH_FFT_HPP
33
#include "common.hpp"
4-
#include "cp-algo/number_theory/modint.hpp"
4+
#include "../number_theory/modint.hpp"
55
#include <algorithm>
66
#include <complex>
77
#include <cassert>

cp-algo/math/poly.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include "poly/impl/base.hpp"
55
#include "poly/impl/div.hpp"
66
#include "combinatorics.hpp"
7-
#include "cp-algo/number_theory/discrete_sqrt.hpp"
7+
#include "../number_theory/discrete_sqrt.hpp"
88
#include "fft.hpp"
99
#include <functional>
1010
#include <algorithm>

cp-algo/number_theory/discrete_log.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ namespace cp_algo::math {
2626
for(size_t k = 0; k < m; k += sqrtmod) {
2727
auto it = small.find((base(c) * cur).getr());
2828
if(it != end(small)) {
29-
auto cand = base::with_mod(period(base(b)), [&](){
30-
return base(it->second - k);
31-
}).getr();
29+
auto cand = base::with_mod(period(base(b)), [&]() {
30+
return base(it->second - k).getr();
31+
});
3232
if(base(a) * bpow(base(b), cand) == base(c)) {
3333
return cand;
3434
} else {

cp-algo/number_theory/discrete_sqrt.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#ifndef CP_ALGO_NUMBER_THEORY_DISCRETE_SQRT_HPP
22
#define CP_ALGO_NUMBER_THEORY_DISCRETE_SQRT_HPP
33
#include "modint.hpp"
4-
#include "cp-algo/random/rng.hpp"
5-
#include "cp-algo/math/affine.hpp"
4+
#include "../random/rng.hpp"
5+
#include "../math/affine.hpp"
66
namespace cp_algo::math {
77
// https://en.wikipedia.org/wiki/Berlekamp-Rabin_algorithm
88
template<modint_type base>

cp-algo/number_theory/factorize.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#ifndef CP_ALGO_MATH_FACTORIZE_HPP
22
#define CP_ALGO_MATH_FACTORIZE_HPP
33
#include "primality.hpp"
4-
#include "cp-algo/random/rng.hpp"
4+
#include "../random/rng.hpp"
55
namespace cp_algo::math {
66
// https://en.wikipedia.org/wiki/Pollard%27s_rho_algorithm
77
void factorize(uint64_t m, std::vector<int64_t> &res) {

cp-algo/number_theory/modint.hpp

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,63 @@
11
#ifndef CP_ALGO_NUMBER_THEORY_MODINT_HPP
22
#define CP_ALGO_NUMBER_THEORY_MODINT_HPP
3-
#include "cp-algo/math/common.hpp"
3+
#include "../math/common.hpp"
44
#include <iostream>
5+
#include <cassert>
56
namespace cp_algo::math {
7+
inline constexpr uint64_t inv64(uint64_t x) {
8+
assert(x % 2);
9+
uint64_t y = 1;
10+
while(y * x != 1) {
11+
y *= 2 - x * y;
12+
}
13+
return y;
14+
}
15+
616
template<typename modint>
717
struct modint_base {
818
static int64_t mod() {
919
return modint::mod();
1020
}
21+
static uint64_t imod() {
22+
return modint::imod();
23+
}
24+
static __uint128_t pw128() {
25+
return modint::pw128();
26+
}
27+
static uint64_t m_reduce(__uint128_t ab) {
28+
if(mod() % 2 == 0) {
29+
return ab % mod();
30+
} else {
31+
uint64_t m = ab * imod();
32+
int64_t res = (ab + __uint128_t(m) * mod()) >> 64;
33+
return res < mod() ? res : res - mod();
34+
}
35+
}
36+
static uint64_t m_transform(uint64_t a) {
37+
if(mod() % 2 == 0) {
38+
return a;
39+
} else {
40+
return m_reduce(a * pw128());
41+
}
42+
}
1143
modint_base(): r(0) {}
1244
modint_base(int64_t rr): r(rr % mod()) {
1345
r = std::min(r, r + mod());
46+
r = m_transform(r);
1447
}
1548
modint inv() const {
1649
return bpow(to_modint(), mod() - 2);
1750
}
18-
modint operator - () const {return std::min(-r, mod() - r);}
51+
modint operator - () const {
52+
modint neg;
53+
neg.r = std::min(-r, mod() - r);
54+
return neg;
55+
}
1956
modint& operator /= (const modint &t) {
2057
return to_modint() *= t.inv();
2158
}
2259
modint& operator *= (const modint &t) {
23-
if(mod() <= uint32_t(-1)) {
24-
r = r * t.r % mod();
25-
} else {
26-
r = __int128(r) * t.r % mod();
27-
}
60+
r = m_reduce(__uint128_t(r) * t.r);
2861
return to_modint();
2962
}
3063
modint& operator += (const modint &t) {
@@ -40,7 +73,10 @@ namespace cp_algo::math {
4073
modint operator * (const modint &t) const {return modint(to_modint()) *= t;}
4174
modint operator / (const modint &t) const {return modint(to_modint()) /= t;}
4275
auto operator <=> (const modint_base &t) const = default;
43-
int64_t rem() const {return 2 * r > (uint64_t)mod() ? r - mod() : r;}
76+
int64_t rem() const {
77+
uint64_t R = getr();
78+
return 2 * R > (uint64_t)mod() ? R - mod() : R;
79+
}
4480

4581
// Only use if you really know what you're doing!
4682
uint64_t modmod() const {return 8ULL * mod() * mod();};
@@ -52,16 +88,21 @@ namespace cp_algo::math {
5288
}
5389
return to_modint();
5490
}
55-
uint64_t& setr() {return r;}
56-
uint64_t getr() const {return r;}
91+
void setr(uint64_t rr) {r = m_transform(rr);}
92+
uint64_t getr() const {return m_reduce(r);}
93+
void setr_direct(uint64_t rr) {r = rr;}
94+
uint64_t getr_direct() const {return r;}
5795
private:
5896
uint64_t r;
5997
modint& to_modint() {return static_cast<modint&>(*this);}
6098
modint const& to_modint() const {return static_cast<modint const&>(*this);}
6199
};
62100
template<typename modint>
63101
std::istream& operator >> (std::istream &in, modint_base<modint> &x) {
64-
return in >> x.setr();
102+
uint64_t r;
103+
auto &res = in >> r;
104+
x.setr(r);
105+
return res;
65106
}
66107
template<typename modint>
67108
std::ostream& operator << (std::ostream &out, modint_base<modint> const& x) {
@@ -73,14 +114,24 @@ namespace cp_algo::math {
73114

74115
template<int64_t m>
75116
struct modint: modint_base<modint<m>> {
117+
static constexpr uint64_t im = m % 2 ? inv64(-m) : 0;
118+
static constexpr uint64_t r2 = __uint128_t(-1) % m + 1;
76119
static constexpr int64_t mod() {return m;}
120+
static constexpr uint64_t imod() {return im;}
121+
static constexpr __uint128_t pw128() {return r2;}
77122
using Base = modint_base<modint<m>>;
78123
using Base::Base;
79124
};
80125

81126
struct dynamic_modint: modint_base<dynamic_modint> {
82127
static int64_t mod() {return m;}
83-
static void switch_mod(int64_t nm) {m = nm;}
128+
static uint64_t imod() {return im;}
129+
static __uint128_t pw128() {return r2;}
130+
static void switch_mod(int64_t nm) {
131+
m = nm;
132+
im = m % 2 ? inv64(-m) : 0;
133+
r2 = __uint128_t(-1) % m + 1;
134+
}
84135
using Base = modint_base<dynamic_modint>;
85136
using Base::Base;
86137

@@ -95,7 +146,10 @@ namespace cp_algo::math {
95146
}
96147
private:
97148
static int64_t m;
149+
static uint64_t im, r1, r2;
98150
};
99-
int64_t dynamic_modint::m = 0;
151+
int64_t dynamic_modint::m = 1;
152+
uint64_t dynamic_modint::im = -1;
153+
uint64_t dynamic_modint::r2 = 0;
100154
}
101155
#endif // CP_ALGO_MATH_MODINT_HPP

verify/linalg/inv.test.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// @brief Inverse Matrix
22
#define PROBLEM "https://judge.yosupo.jp/problem/inverse_matrix"
33
#pragma GCC optimize("Ofast,unroll-loops")
4-
#pragma GCC target("tune=native")
54
#include "cp-algo/linalg/matrix.hpp"
65
#include <bits/stdc++.h>
76

@@ -32,4 +31,4 @@ signed main() {
3231
while(t--) {
3332
solve();
3433
}
35-
}
34+
}

verify/linalg/prod.test.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// @brief Matrix Product
22
#define PROBLEM "https://judge.yosupo.jp/problem/matrix_product"
33
#pragma GCC optimize("Ofast,unroll-loops")
4-
#pragma GCC target("tune=native")
54
#include "cp-algo/linalg/matrix.hpp"
65
#include <bits/stdc++.h>
76

@@ -29,4 +28,4 @@ signed main() {
2928
while(t--) {
3029
solve();
3130
}
32-
}
31+
}

verify/number_theory/discrete_log.test.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// @brief Discrete Logarithm
22
#define PROBLEM "https://judge.yosupo.jp/problem/discrete_logarithm_mod"
33
#pragma GCC optimize("Ofast,unroll-loops")
4-
#pragma GCC target("tune=native")
54
#include "cp-algo/number_theory/discrete_log.hpp"
65
#include <bits/stdc++.h>
76

@@ -30,4 +29,4 @@ signed main() {
3029
while(t--) {
3130
solve();
3231
}
33-
}
32+
}

0 commit comments

Comments
 (0)