Skip to content

Commit f27f75f

Browse files
committed
Primality test + support large mod
1 parent f55ce90 commit f27f75f

File tree

3 files changed

+87
-10
lines changed

3 files changed

+87
-10
lines changed

cp-algo/algebra/modint.hpp

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
namespace cp_algo::algebra {
66
template<typename modint>
77
struct modint_base {
8-
static int mod() {
8+
static int64_t mod() {
99
return modint::mod();
1010
}
1111
modint_base(): r(0) {}
@@ -20,7 +20,11 @@ namespace cp_algo::algebra {
2020
return to_modint() *= t.inv();
2121
}
2222
modint& operator *= (const modint &t) {
23-
r *= t.r; if(mod()) {r %= mod();}
23+
if(mod() <= uint32_t(-1)) {
24+
r = r * t.r % mod();
25+
} else {
26+
r = __int128(r) * t.r % mod();
27+
}
2428
return to_modint();
2529
}
2630
modint& operator += (const modint &t) {
@@ -36,11 +40,10 @@ namespace cp_algo::algebra {
3640
modint operator * (const modint &t) const {return modint(to_modint()) *= t;}
3741
modint operator / (const modint &t) const {return modint(to_modint()) /= t;}
3842
auto operator <=> (const modint_base &t) const = default;
39-
explicit operator int() const {return r;}
4043
int64_t rem() const {return 2 * r > (uint64_t)mod() ? r - mod() : r;}
4144

4245
// Only use if you really know what you're doing!
43-
uint64_t modmod() const {return 8LL * mod() * mod();};
46+
uint64_t modmod() const {return 8ULL * mod() * mod();};
4447
void add_unsafe(uint64_t t) {r += t;}
4548
void pseudonormalize() {r = std::min(r, r - modmod());}
4649
modint const& normalize() {
@@ -65,21 +68,30 @@ namespace cp_algo::algebra {
6568
return out << x.getr();
6669
}
6770

68-
template<int m>
71+
template<int64_t m>
6972
struct modint: modint_base<modint<m>> {
70-
static constexpr int mod() {return m;}
73+
static constexpr int64_t mod() {return m;}
7174
using Base = modint_base<modint<m>>;
7275
using Base::Base;
7376
};
7477

7578
struct dynamic_modint: modint_base<dynamic_modint> {
76-
static int mod() {return m;}
77-
static void switch_mod(int nm) {m = nm;}
79+
static int64_t mod() {return m;}
80+
static void switch_mod(int64_t nm) {m = nm;}
7881
using Base = modint_base<dynamic_modint>;
7982
using Base::Base;
83+
84+
// Wrapper for temp switching
85+
auto static with_switched_mod(int64_t tmp, auto callback) {
86+
auto prev = mod();
87+
switch_mod(tmp);
88+
auto res = callback();
89+
switch_mod(prev);
90+
return res;
91+
}
8092
private:
81-
static int m;
93+
static int64_t m;
8294
};
83-
int dynamic_modint::m = 0;
95+
int64_t dynamic_modint::m = 0;
8496
}
8597
#endif // CP_ALGO_ALGEBRA_MODINT_HPP

cp-algo/algebra/number_theory.hpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,44 @@ namespace cp_algo::algebra {
2727
}
2828
}
2929
}
30+
31+
template<typename base>
32+
requires(std::is_base_of_v<modint_base<base>, base>)
33+
bool is_prime_mod() {
34+
auto m = base::mod();
35+
if(m == 1 || m % 2 == 0) {
36+
return m == 2;
37+
}
38+
auto m1 = m - 1;
39+
int d = 0;
40+
while(m1 % 2 == 0) {
41+
m1 /= 2;
42+
d++;
43+
}
44+
auto test = [&](auto x) {
45+
x = bpow(x, m1);
46+
if(x == 0 || x == 1 || x == -1) {
47+
return true;
48+
}
49+
for(int i = 0; i <= d; i++) {
50+
if(x == -1) {
51+
return true;
52+
}
53+
x *= x;
54+
}
55+
return false;
56+
};
57+
for(base b: {2, 325, 9375, 28178, 450775, 9780504, 1795265022}) {
58+
if(!test(b)) {
59+
return false;
60+
}
61+
}
62+
return true;
63+
}
64+
bool is_prime(int64_t m) {
65+
return dynamic_modint::with_switched_mod(m, [](){
66+
return is_prime_mod<dynamic_modint>();
67+
});
68+
}
3069
}
3170
#endif // CP_ALGO_ALGEBRA_NUMBER_THEORY_HPP
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// @brief Primality Test
2+
#define PROBLEM "https://judge.yosupo.jp/problem/primality_test"
3+
#pragma GCC optimize("Ofast,unroll-loops")
4+
#pragma GCC target("avx2,tune=native")
5+
#include "cp-algo/algebra/number_theory.hpp"
6+
#include <bits/stdc++.h>
7+
8+
using namespace std;
9+
using namespace cp_algo::algebra;
10+
11+
void solve() {
12+
int64_t m;
13+
cin >> m;
14+
cout << (is_prime(m) ? "Yes" : "No") << "\n";
15+
}
16+
17+
signed main() {
18+
//freopen("input.txt", "r", stdin);
19+
ios::sync_with_stdio(0);
20+
cin.tie(0);
21+
int t = 1;
22+
cin >> t;
23+
while(t--) {
24+
solve();
25+
}
26+
}

0 commit comments

Comments
 (0)