Skip to content

Commit 6ee273e

Browse files
Kolarusimeonschaub
andauthored
Implement to_vec for matrix factorizations (#128)
* Implement to_vec for factorizations * Fix inferrence for svd * Fix problem from qr(X).T undef values * Fix tests * Reduce number of tests Co-authored-by: Simeon Schaub <[email protected]> Co-authored-by: Simeon Schaub <[email protected]>
1 parent 1342e42 commit 6ee273e

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

src/to_vec.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,51 @@ function to_vec(X::T) where {T<:PermutedDimsArray}
156156
return x_vec, PermutedDimsArray_from_vec
157157
end
158158

159+
# Factorizations
160+
161+
function to_vec(x::F) where {F <: SVD}
162+
# Convert the vector S to a matrix so we can work with a vector of matrices
163+
# only and inferrence work
164+
v = [x.U, reshape(x.S, length(x.S), 1), x.Vt]
165+
x_vec, back = to_vec(v)
166+
function SVD_from_vec(v)
167+
U, Smat, Vt = back(v)
168+
return F(U, vec(Smat), Vt)
169+
end
170+
return x_vec, SVD_from_vec
171+
end
172+
173+
function to_vec(x::Cholesky)
174+
x_vec, back = to_vec(x.factors)
175+
function Cholesky_from_vec(v)
176+
return Cholesky(back(v), x.uplo, x.info)
177+
end
178+
return x_vec, Cholesky_from_vec
179+
end
180+
181+
function to_vec(x::S) where {U, S <: Union{LinearAlgebra.QRCompactWYQ{U}, LinearAlgebra.QRCompactWY{U}}}
182+
# x.T is composed of upper triangular blocks. The subdiagonals elements
183+
# of the blocks are abitrary. We make sure to set all of them to zero
184+
# to avoid NaN.
185+
blocksize, cols = size(x.T)
186+
T = zeros(U, blocksize, cols)
187+
188+
for i in 0:div(cols - 1, blocksize)
189+
used_cols = i * blocksize
190+
n = min(blocksize, cols - used_cols)
191+
T[1:n, (1:n) .+ used_cols] = UpperTriangular(view(x.T, 1:n, (1:n) .+ used_cols))
192+
end
193+
194+
x_vec, back = to_vec([x.factors, T])
195+
196+
function QRCompact_from_vec(v)
197+
factors, Tback = back(v)
198+
return S(factors, Tback)
199+
end
200+
201+
return x_vec, QRCompact_from_vec
202+
end
203+
159204
# Non-array data structures
160205

161206
function to_vec(x::Tuple)

test/to_vec.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ function test_to_vec(x::T; check_inferred=true) where {T}
6161
x_vec, back = to_vec(x)
6262
@test x_vec isa Vector
6363
@test all(s -> s isa Real, x_vec)
64+
@test all(!isnan, x_vec)
6465
check_inferred && @inferred back(x_vec)
6566
@test x == back(x_vec)
6667
return nothing
@@ -128,6 +129,47 @@ end
128129
)
129130
end
130131

132+
@testset "Factorizations" begin
133+
# (100, 100) is needed to test for the NaNs that can appear in the
134+
# qr(M).T matrix
135+
for dims in [(7, 3), (100, 100)]
136+
M = randn(T, dims...)
137+
P = M * M' + I # Positive definite matrix
138+
test_to_vec(svd(M))
139+
test_to_vec(cholesky(P))
140+
141+
# Special treatment for QR since it is represented by a matrix
142+
# with some arbirtrary values.
143+
F = qr(M)
144+
@inferred to_vec(F)
145+
F_vec, back = to_vec(F)
146+
@test F_vec isa Vector
147+
@test all(s -> s isa Real, F_vec)
148+
@test all(!isnan, F_vec)
149+
@inferred back(F_vec)
150+
F_back = back(F_vec)
151+
@test F_back.Q == F.Q
152+
@test F_back.R == F.R
153+
154+
# Make sure the result is consistent despite the arbitrary
155+
# values in F.T.
156+
@test first(to_vec(F)) == first(to_vec(F))
157+
158+
# Test F.Q as well since it has a special type. Since it is
159+
# represented by the same T and factors matrices than F
160+
# it needs the same special treatment.
161+
Q = F.Q
162+
@inferred to_vec(Q)
163+
Q_vec, back = to_vec(Q)
164+
@test Q_vec isa Vector
165+
@test all(s -> s isa Real, Q_vec)
166+
@test all(!isnan, Q_vec)
167+
@inferred back(Q_vec)
168+
Q_back = back(Q_vec)
169+
@test Q_back == Q
170+
end
171+
end
172+
131173
@testset "Tuples" begin
132174
test_to_vec((5, 4))
133175
test_to_vec((5, randn(T, 5)); check_inferred = VERSION v"1.2") # broken on Julia 1.6.0, fixed on 1.6.1

0 commit comments

Comments
 (0)