Skip to content

Commit 9ed82eb

Browse files
committed
Introduce modint_type concept, fix fft
1 parent f27f75f commit 9ed82eb

File tree

3 files changed

+16
-15
lines changed

3 files changed

+16
-15
lines changed

cp-algo/algebra/fft.hpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ namespace cp_algo::algebra::fft {
7070
}
7171
}
7272

73-
template<int m>
73+
template<modint_type base>
7474
struct dft {
7575
static constexpr int split = 1 << 15;
7676
std::vector<point> A;
7777

78-
dft(std::vector<modint<m>> const& a, size_t n): A(n) {
78+
dft(std::vector<base> const& a, size_t n): A(n) {
7979
for(size_t i = 0; i < std::min(n, a.size()); i++) {
8080
A[i] = point(
8181
a[i].rem() % split,
@@ -91,7 +91,7 @@ namespace cp_algo::algebra::fft {
9191
assert(A.size() == B.A.size());
9292
size_t n = A.size();
9393
if(!n) {
94-
return std::vector<modint<m>>();
94+
return std::vector<base>();
9595
}
9696
std::vector<point> C(n), D(n);
9797
for(size_t i = 0; i < n; i++) {
@@ -103,11 +103,11 @@ namespace cp_algo::algebra::fft {
103103
reverse(begin(C) + 1, end(C));
104104
reverse(begin(D) + 1, end(D));
105105
int t = 2 * n;
106-
std::vector<modint<m>> res(n);
106+
std::vector<base> res(n);
107107
for(size_t i = 0; i < n; i++) {
108-
modint<m> A0 = llround(C[i].real() / t);
109-
modint<m> A1 = llround(C[i].imag() / t + D[i].imag() / t);
110-
modint<m> A2 = llround(D[i].real() / t);
108+
base A0 = llround(C[i].real() / t);
109+
base A1 = llround(C[i].imag() / t + D[i].imag() / t);
110+
base A2 = llround(D[i].real() / t);
111111
res[i] = A0 + A1 * split - A2 * split * split;
112112
}
113113
return res;
@@ -128,18 +128,18 @@ namespace cp_algo::algebra::fft {
128128
return n;
129129
}
130130

131-
template<int m>
132-
void mul(std::vector<modint<m>> &a, std::vector<modint<m>> b) {
131+
template<modint_type base>
132+
void mul(std::vector<base> &a, std::vector<base> const& b) {
133133
if(std::min(a.size(), b.size()) < magic) {
134134
mul_slow(a, b);
135135
return;
136136
}
137137
auto n = com_size(a.size(), b.size());
138-
auto A = dft<m>(a, n);
138+
auto A = dft<base>(a, n);
139139
if(a == b) {
140140
a = A * A;
141141
} else {
142-
a = A * dft<m>(b, n);
142+
a = A * dft<base>(b, n);
143143
}
144144
}
145145
}

cp-algo/algebra/modint.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ namespace cp_algo::algebra {
6868
return out << x.getr();
6969
}
7070

71+
template<typename modint>
72+
concept modint_type = std::is_base_of_v<modint_base<modint>, modint>;
73+
7174
template<int64_t m>
7275
struct modint: modint_base<modint<m>> {
7376
static constexpr int64_t mod() {return m;}

cp-algo/algebra/number_theory.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
#include <optional>
77
namespace cp_algo::algebra {
88
// https://en.wikipedia.org/wiki/Berlekamp-Rabin_algorithm
9-
template<typename base>
10-
requires(std::is_base_of_v<modint_base<base>, base>)
9+
template<modint_type base>
1110
std::optional<base> sqrt(base b) {
1211
if(b == base(0)) {
1312
return base(0);
@@ -28,8 +27,7 @@ namespace cp_algo::algebra {
2827
}
2928
}
3029

31-
template<typename base>
32-
requires(std::is_base_of_v<modint_base<base>, base>)
30+
template<modint_type base>
3331
bool is_prime_mod() {
3432
auto m = base::mod();
3533
if(m == 1 || m % 2 == 0) {

0 commit comments

Comments
 (0)