Skip to content

Commit 0440f37

Browse files
committed
Add superfast matrix pow test
1 parent 28ae587 commit 0440f37

File tree

7 files changed

+132
-64
lines changed

7 files changed

+132
-64
lines changed

cp-algo/algebra/poly.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace cp_algo::algebra {
3333
poly_t mod_xk(size_t k) const {return poly::impl::mod_xk(*this, k);} // %= x^k
3434
poly_t mul_xk(size_t k) const {return poly::impl::mul_xk(*this, k);} // *= x^k
3535
poly_t div_xk(size_t k) const {return poly::impl::div_xk(*this, k);} // /= x^k
36-
poly_t substr(size_t l, size_t r) const {return poly::impl::substr(*this, l, r);}
36+
poly_t substr(size_t l, size_t k) const {return poly::impl::substr(*this, l, k);}
3737

3838
poly_t operator *= (const poly_t &t) {fft::mul(a, t.a); normalize(); return *this;}
3939
poly_t operator * (const poly_t &t) const {return poly_t(*this) *= t;}

cp-algo/algebra/poly/impl/base.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ namespace cp_algo::algebra::poly::impl {
5151
}
5252
return std::vector(begin(p.a) + std::min<size_t>(k, p.a.size()), end(p.a));
5353
}
54-
auto substr(auto const& p, size_t l, size_t r) {
54+
auto substr(auto const& p, size_t l, size_t k) {
5555
return std::vector(
5656
begin(p.a) + std::min(l, p.a.size()),
57-
begin(p.a) + std::min(r, p.a.size())
57+
begin(p.a) + std::min(l + k, p.a.size())
5858
);
5959
}
6060
auto reverse(auto p, size_t n) {

cp-algo/linalg/frobenius.hpp

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,64 +3,74 @@
33
#include "matrix.hpp"
44
#include "../algebra/poly.hpp"
55
#include <vector>
6-
namespace cp_algo::linalg {
7-
template<bool reduce = false>
8-
auto frobenius_basis(auto const& A) {
6+
namespace cp_algo::linalg {
7+
enum frobenius_mode {blocks, full};
8+
template<frobenius_mode mode = blocks>
9+
auto frobenius_form(auto const& A) {
910
using matrix = std::decay_t<decltype(A)>;
1011
using base = matrix::base;
1112
using polyn = algebra::poly_t<base>;
1213
assert(A.n() == A.m());
1314
size_t n = A.n();
14-
struct krylov {
15-
std::vector<vec<base>> basis;
16-
polyn rec;
17-
};
18-
std::vector<krylov> blocks;
19-
std::vector<vec<base>> reduced;
20-
while(size(reduced) < n) {
15+
std::vector<polyn> charps;
16+
std::vector<vec<base>> basis, basis_init;
17+
while(size(basis) < n) {
18+
size_t start = size(basis);
2119
auto generate_block = [&](auto x) {
22-
krylov block;
2320
while(true) {
24-
vec<base> y = x | vec<base>::ei(n + 1, size(reduced));
25-
for(auto &it: reduced) {
21+
vec<base> y = x | vec<base>::ei(n + 1, size(basis));
22+
for(auto &it: basis) {
2623
y.reduce_by(it);
2724
}
2825
y.normalize();
2926
if(vec<base>(y[std::slice(0, n, 1)]) == vec<base>(n)) {
30-
block.rec = std::vector<base>(
31-
begin(y) + n + size(reduced) - size(block.basis),
32-
begin(y) + n + size(reduced) + 1
33-
);
34-
return std::pair{block, vec<base>(y[std::slice(n, n, 1)])};
27+
return polyn(std::vector<base>(begin(y) + n, end(y)));
3528
} else {
36-
block.basis.push_back(x);
37-
reduced.push_back(y);
29+
basis_init.push_back(x);
30+
basis.push_back(y);
3831
x = A.apply(x);
3932
}
4033
}
4134
};
42-
auto [block, full_rec] = generate_block(vec<base>::random(n));
43-
if constexpr (reduce) {
44-
if(vec<base>(full_rec[std::slice(0, size(reduced), 1)]) != vec<base>(size(reduced))) {
45-
auto x = block.basis[0];
46-
size_t start = 0;
47-
for(auto &[basis, rec]: blocks) {
48-
polyn cur_rec = std::vector<base>(
49-
begin(full_rec) + start, begin(full_rec) + start + rec.deg()
50-
);
51-
auto shift = cur_rec / block.rec;
35+
auto full_rec = generate_block(vec<base>::random(n));
36+
// Extra trimming to make it block-diagonal (expensive)
37+
if constexpr (mode == full) {
38+
if(full_rec.mod_xk(start) != polyn()) {
39+
auto charp = full_rec.div_xk(start);
40+
auto x = basis_init[start];
41+
start = 0;
42+
for(auto &rec: charps) {
43+
polyn cur_rec = full_rec.substr(start, rec.deg());
44+
auto shift = cur_rec / charp;
5245
for(int j = 0; j <= shift.deg(); j++) {
53-
x.add_scaled(basis[j], shift[j]);
46+
x.add_scaled(basis_init[start + j], shift[j]);
5447
}
5548
start += rec.deg();
5649
}
57-
reduced.erase(begin(reduced) + start, end(reduced));
58-
tie(block, full_rec) = generate_block(x.normalize());
50+
basis.resize(start);
51+
basis_init.resize(start);
52+
full_rec = generate_block(x.normalize());
5953
}
6054
}
61-
blocks.push_back(block);
55+
charps.push_back(full_rec.div_xk(start));
56+
}
57+
// Find transform matrices while we're at it...
58+
if constexpr (mode == full) {
59+
for(size_t i = 0; i < size(basis); i++) {
60+
for(size_t j = i + 1; j < size(basis); j++) {
61+
basis[i].reduce_by(basis[j]);
62+
}
63+
basis[i].normalize();
64+
basis[i] = vec<base>(
65+
basis[i][std::slice(n, n, 1)]
66+
) * (base(1) / basis[i][i]);
67+
}
68+
auto T = matrix::from_range(basis_init);
69+
auto Tinv = matrix::from_range(basis);
70+
return std::tuple{T, Tinv, charps};
71+
} else {
72+
return charps;
6273
}
63-
return blocks;
6474
}
6575
};
66-
#endif // CP_ALGO_LINALG_FROBENIUS_HPP
76+
#endif // CP_ALGO_LINALG_FROBENIUS_HPP

cp-algo/linalg/matrix.hpp

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,30 @@ namespace cp_algo::linalg {
4242
}
4343
}
4444

45+
static matrix block_diagonal(std::vector<matrix> const& blocks) {
46+
size_t n = 0;
47+
for(auto &it: blocks) {
48+
assert(it.n() == it.m());
49+
n += it.n();
50+
}
51+
matrix res(n);
52+
n = 0;
53+
for(auto &it: blocks) {
54+
for(size_t i = 0; i < it.n(); i++) {
55+
res[n + i][std::slice(n, it.n(), 1)] = it[i];
56+
}
57+
n += it.n();
58+
}
59+
return res;
60+
}
61+
static matrix random(size_t n, size_t m) {
62+
matrix res(n, m);
63+
std::ranges::generate(res, std::bind(vec<base>::random, m));
64+
return res;
65+
}
66+
static matrix random(size_t n) {
67+
return random(n, n);
68+
}
4569
static matrix eye(size_t n) {
4670
matrix res(n);
4771
for(size_t i = 0; i < n; i++) {
@@ -97,24 +121,15 @@ namespace cp_algo::linalg {
97121
return bpow(*this, k, eye(n()));
98122
}
99123

100-
static matrix random(size_t n, size_t m) {
101-
matrix res(n, m);
102-
std::ranges::generate(res, std::bind(vec<base>::random, m));
103-
return res;
104-
}
105-
static matrix random(size_t n) {
106-
return random(n, n);
107-
}
108-
109124
matrix& normalize() {
110125
for(auto &it: *this) {
111126
it.normalize();
112127
}
113128
return *this;
114129
}
115130

116-
enum Mode {normal, reverse};
117-
template<Mode mode = normal>
131+
enum gauss_mode {normal, reverse};
132+
template<gauss_mode mode = normal>
118133
matrix& gauss() {
119134
for(size_t i = 0; i < n(); i++) {
120135
row(i).normalize();
@@ -126,11 +141,11 @@ namespace cp_algo::linalg {
126141
}
127142
return normalize();
128143
}
129-
template<Mode mode = normal>
144+
template<gauss_mode mode = normal>
130145
auto echelonize(size_t lim) {
131146
return gauss<mode>().sort_classify(lim);
132147
}
133-
template<Mode mode = normal>
148+
template<gauss_mode mode = normal>
134149
auto echelonize() {
135150
return echelonize<mode>(m());
136151
}

cp-algo/linalg/vector.hpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ namespace cp_algo::linalg {
2424
bool operator != (vec const& t) const {return !(*this == t);}
2525

2626
vec operator-() const {return Base::operator-();}
27-
vec operator-(vec const& t) const {return Base::operator-(t);}
28-
vec operator+(vec const& t) const {return Base::operator+(t);}
2927

3028
static vec from_range(auto const& R) {
3129
vec res(std::ranges::distance(R));
@@ -45,9 +43,6 @@ namespace cp_algo::linalg {
4543
return res;
4644
}
4745

48-
// Make sure the result is vec, not Base
49-
vec operator*(base t) const {return Base::operator*(t);}
50-
5146
void add_scaled(vec const& b, base scale, size_t i = 0) {
5247
assert(false);
5348
for(; i < size(*this); i++) {

verify/algebra/matrix/characteristic.test.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,8 @@ void solve() {
1818
cin >> n;
1919
matrix<base> A(n);
2020
A.read();
21-
auto blocks = frobenius_basis(A);
22-
polyn res(1);
23-
for(auto &[basis, rec]: blocks) {
24-
res *= rec;
25-
}
26-
res.print();
21+
auto blocks = frobenius_form(A);
22+
reduce(begin(blocks), end(blocks), polyn(1), multiplies{}).print();
2723
}
2824

2925
signed main() {
@@ -35,4 +31,4 @@ signed main() {
3531
while(t--) {
3632
solve();
3733
}
38-
}
34+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// @brief Pow of Matrix (Frobenius)
2+
#define PROBLEM "https://judge.yosupo.jp/problem/pow_of_matrix"
3+
#pragma GCC optimize("Ofast,unroll-loops")
4+
#pragma GCC target("avx2,tune=native")
5+
#include "cp-algo/linalg/frobenius.hpp"
6+
#include <bits/stdc++.h>
7+
8+
using namespace std;
9+
using namespace cp_algo::algebra;
10+
using namespace cp_algo::linalg;
11+
12+
const int mod = 998244353;
13+
using base = modular<mod>;
14+
using polyn = poly_t<base>;
15+
16+
template<typename base>
17+
auto frobenius_pow(matrix<base> A, uint64_t k) {
18+
using polyn = poly_t<base>;
19+
auto [T, Tinv, charps] = frobenius_form<full>(A);
20+
vector<matrix<base>> blocks;
21+
for(auto charp: charps) {
22+
matrix<base> block(charp.deg());
23+
auto xk = polyn::xk(1).powmod(k, charp);
24+
for(size_t i = 0; i < block.n(); i++) {
25+
ranges::copy(xk.a, begin(block[i]));
26+
xk = xk.mul_xk(1) % charp;
27+
}
28+
blocks.push_back(block);
29+
}
30+
auto S = matrix<base>::block_diagonal(blocks);
31+
return Tinv * S * T;
32+
}
33+
34+
void solve() {
35+
size_t n;
36+
uint64_t k;
37+
cin >> n >> k;
38+
matrix<base> A(n);
39+
A.read();
40+
frobenius_pow(A, k).print();
41+
}
42+
43+
signed main() {
44+
//freopen("input.txt", "r", stdin);
45+
ios::sync_with_stdio(0);
46+
cin.tie(0);
47+
int t = 1;
48+
//cin >> t;
49+
while(t--) {
50+
solve();
51+
}
52+
}

0 commit comments

Comments
 (0)