7
7
#include < vector>
8
8
#include < map>
9
9
namespace cp_algo ::math {
10
- using gaussint = complex <int64_t >;
11
- gaussint two_squares_prime_any (int64_t p) {
10
+ template <typename T>
11
+ using gaussint = complex <T>;
12
+ template <typename _Int>
13
+ auto two_squares_prime_any (_Int p) {
12
14
if (p == 2 ) {
13
- return gaussint ( 1 , 1 ) ;
15
+ return gaussint<_Int>{ 1 , 1 } ;
14
16
}
15
17
assert (p % 4 == 1 );
16
- using base = dynamic_modint<>;
18
+ using Int = std::make_signed_t <_Int>;
19
+ using base = dynamic_modint<Int>;
17
20
return base::with_mod (p, [&](){
18
21
base g = primitive_root (p);
19
22
int64_t i = bpow (g, (p - 1 ) / 4 ).getr ();
@@ -25,49 +28,50 @@ namespace cp_algo::math {
25
28
q0 = std::exchange (q1, q0 + d * q1);
26
29
r = std::exchange (m, r % m);
27
30
} while (q1 < p / q1);
28
- return gaussint ( q0, (base (i) * base (q0)).rem ()) ;
31
+ return gaussint<_Int>{ q0, (base (i) * base (q0)).rem ()} ;
29
32
});
30
33
}
31
34
32
- std::vector<gaussint> two_squares_all (int64_t n) {
35
+ template <typename Int>
36
+ std::vector<gaussint<Int>> two_squares_all (Int n) {
33
37
if (n == 0 ) {
34
38
return {0 };
35
39
}
36
40
auto primes = factorize (n);
37
- std::map<int64_t , int > cnt;
41
+ std::map<Int , int > cnt;
38
42
for (auto p: primes) {
39
43
cnt[p]++;
40
44
}
41
- std::vector<gaussint> res = {1 };
45
+ std::vector<gaussint<Int> > res = {1 };
42
46
for (auto [p, c]: cnt) {
43
- std::vector<gaussint> nres;
47
+ std::vector<gaussint<Int> > nres;
44
48
if (p % 4 == 3 ) {
45
49
if (c % 2 == 0 ) {
46
- auto mul = bpow (gaussint (p), c / 2 );
50
+ auto mul = bpow (gaussint<Int> (p), c / 2 );
47
51
for (auto p: res) {
48
52
nres.push_back (p * mul);
49
53
}
50
54
}
51
55
} else if (p % 4 == 1 ) {
52
- gaussint base = two_squares_prime_any (p);
56
+ auto base = two_squares_prime_any (p);
53
57
for (int i = 0 ; i <= c; i++) {
54
58
auto mul = bpow (base, i) * bpow (conj (base), c - i);
55
59
for (auto p: res) {
56
60
nres.push_back (p * mul);
57
61
}
58
62
}
59
63
} else if (p % 4 == 2 ) {
60
- auto mul = bpow (gaussint (1 , 1 ), c);
64
+ auto mul = bpow (gaussint<Int> (1 , 1 ), c);
61
65
for (auto p: res) {
62
66
nres.push_back (p * mul);
63
67
}
64
68
}
65
69
res = nres;
66
70
}
67
- std::vector<gaussint> nres;
71
+ std::vector<gaussint<Int> > nres;
68
72
for (auto p: res) {
69
73
while (p.real () < 0 || p.imag () < 0 ) {
70
- p *= gaussint (0 , 1 );
74
+ p *= gaussint<Int> (0 , 1 );
71
75
}
72
76
nres.push_back (p);
73
77
if (!p.real () || !p.imag ()) {
0 commit comments