@@ -70,12 +70,12 @@ namespace cp_algo::algebra::fft {
70
70
}
71
71
}
72
72
73
- template <int m >
73
+ template <modint_type base >
74
74
struct dft {
75
75
static constexpr int split = 1 << 15 ;
76
76
std::vector<point> A;
77
77
78
- dft (std::vector<modint<m> > const & a, size_t n): A(n) {
78
+ dft (std::vector<base > const & a, size_t n): A(n) {
79
79
for (size_t i = 0 ; i < std::min (n, a.size ()); i++) {
80
80
A[i] = point (
81
81
a[i].rem () % split,
@@ -91,7 +91,7 @@ namespace cp_algo::algebra::fft {
91
91
assert (A.size () == B.A .size ());
92
92
size_t n = A.size ();
93
93
if (!n) {
94
- return std::vector<modint<m> >();
94
+ return std::vector<base >();
95
95
}
96
96
std::vector<point> C (n), D (n);
97
97
for (size_t i = 0 ; i < n; i++) {
@@ -103,11 +103,11 @@ namespace cp_algo::algebra::fft {
103
103
reverse (begin (C) + 1 , end (C));
104
104
reverse (begin (D) + 1 , end (D));
105
105
int t = 2 * n;
106
- std::vector<modint<m> > res (n);
106
+ std::vector<base > res (n);
107
107
for (size_t i = 0 ; i < n; i++) {
108
- modint<m> A0 = llround (C[i].real () / t);
109
- modint<m> A1 = llround (C[i].imag () / t + D[i].imag () / t);
110
- modint<m> A2 = llround (D[i].real () / t);
108
+ base A0 = llround (C[i].real () / t);
109
+ base A1 = llround (C[i].imag () / t + D[i].imag () / t);
110
+ base A2 = llround (D[i].real () / t);
111
111
res[i] = A0 + A1 * split - A2 * split * split;
112
112
}
113
113
return res;
@@ -128,18 +128,18 @@ namespace cp_algo::algebra::fft {
128
128
return n;
129
129
}
130
130
131
- template <int m >
132
- void mul (std::vector<modint<m>> &a, std::vector<modint<m>> b) {
131
+ template <modint_type base >
132
+ void mul (std::vector<base> &a, std::vector<base> const & b) {
133
133
if (std::min (a.size (), b.size ()) < magic) {
134
134
mul_slow (a, b);
135
135
return ;
136
136
}
137
137
auto n = com_size (a.size (), b.size ());
138
- auto A = dft<m >(a, n);
138
+ auto A = dft<base >(a, n);
139
139
if (a == b) {
140
140
a = A * A;
141
141
} else {
142
- a = A * dft<m >(b, n);
142
+ a = A * dft<base >(b, n);
143
143
}
144
144
}
145
145
}
0 commit comments