Skip to content

Implement solve (\) for general rectangular static matrices #1313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ include("matrix_multiply.jl")
include("lu.jl")
include("det.jl")
include("inv.jl")
include("qr.jl")
include("solve.jl")
include("eigen.jl")
include("expm.jl")
Expand All @@ -128,7 +129,6 @@ include("lyap.jl")
include("triangular.jl")
include("cholesky.jl")
include("svd.jl")
include("qr.jl")
include("deque.jl")
include("flatten.jl")
include("io.jl")
Expand Down
97 changes: 90 additions & 7 deletions src/solve.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,81 @@
@inline (\)(a::StaticMatrix, b::StaticVecOrMat) = _solve(Size(a), Size(b), a, b)
@inline (\)(q::QR, b::StaticVecOrMat) = _solve(Size(q.Q), Size(q.R), q, b)
@inline function _solve(::Size{Sq}, ::Size{Sr}, q::QR, b::StaticVecOrMat) where {Sq, Sr}
Sa = (Sq[1], Sr[2]) # Size of the original matrix: Q * R
Q, R, p = q.Q, q.R, q.p
y = Q' * b
Z = if Sa[1] == Sa[2]
UpperTriangular(R) \ y
elseif Sa[1] > Sa[2]
R₁ = UpperTriangular(@view R[SOneTo(Sa[2]), SOneTo(Sa[2])])
R₁ \ y
else
_wide_qr_solve(q, b)
end
# Apply pivot permutation if necessary
return if p != SOneTo(length(p))
invp = invperm(p)
@view Z[invp, :]
else
Z
end
end

# based on https://github.com/JuliaLang/LinearAlgebra.jl/blob/16f64e78769d788376df0f36447affdb7b1b3df6/src/qr.jl#L652C1-L697C4
function _wide_qr_solve(A::QR{T}, B::StaticMatrix{mB,nB,T}) where {mB,nB,T}
m, n = size(A)
minmn = min(m, n)
Bbuffer = similar(B)
copyto!(Bbuffer, B)
lmul!(adjoint(A.Q), view(Bbuffer, 1:m, :))
Rbuffer = similar(A.R)
copyto!(Rbuffer, A.R)

@inbounds begin
if n > m # minimum norm solution
τ = zeros(T,m)
for k = m:-1:1 # Trapezoid to triangular by elementary operation
x = view(Rbuffer, k, [k; m + 1:n])
τk = LinearAlgebra.reflector!(x)
τ[k] = conj(τk)
for i = 1:k - 1
vRi = Rbuffer[i,k]
for j = m + 1:n
vRi += Rbuffer[i,j]*x[j - m + 1]'
end
vRi *= τk
Rbuffer[i,k] -= vRi
for j = m + 1:n
Rbuffer[i,j] -= vRi*x[j - m + 1]
end
end
end
end
ldiv!(UpperTriangular(view(Rbuffer, :, SOneTo(minmn))), view(Bbuffer, SOneTo(minmn), :))
if n > m # Apply elementary transformation to solution
Bbuffer[m + 1:mB,1:nB] .= zero(T)
for j = 1:nB
for k = 1:m
vBj = Bbuffer[k,j]'
for i = m + 1:n
vBj += Bbuffer[i,j]'*Rbuffer[k,i]'
end
vBj *= τ[k]
Bbuffer[k,j] -= vBj'
for i = m + 1:n
Bbuffer[i,j] -= Rbuffer[k,i]'*vBj'
end
end
end
end
end
return similar_type(B)(Bbuffer)
end
function _wide_qr_solve(q::QR, b::StaticVecOrMat)
Q, R = q.Q, q.R
y = Q' * b
return R' * ((R * R') \ y)
end

@inline function _solve(::Size{(1,1)}, ::Size{(1,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb}
@inbounds return similar_type(b, typeof(a[1] \ b[1]))(a[1] \ b[1])
Expand Down Expand Up @@ -55,15 +132,21 @@ end
throw(DimensionMismatch("Left and right hand side first dimensions do not match in backdivide (got sizes $Sa and $Sb)"))
end
end
if prod(Sa) ≤ 14*14 && Sa[1] == Sa[2]
if prod(Sa) ≤ 14*14
# TODO: Consider triangular special cases as in Base?
quote
@_inline_meta
LUp = lu(a)
LUp.U \ (LUp.L \ $(length(Sb) > 1 ? :(b[LUp.p,:]) : :(b[LUp.p])))
if Sa[1] == Sa[2]
quote
@_inline_meta
LUp = lu(a)
LUp.U \ (LUp.L \ $(length(Sb) > 1 ? :(b[LUp.p,:]) : :(b[LUp.p])))
end
else
quote
@_inline_meta
q = qr(a, ColumnNorm())
q \ b
end
end
# TODO: Could also use static QR here if `a` is nonsquare.
# Requires that we implement \(::StaticArrays.QR,::StaticVecOrMat)
else
# Fall back to LinearAlgebra, but carry across the statically known size.
outsize = length(Sb) == 1 ? Size(Sa[2]) : Size(Sa[2],Sb[end])
Expand Down
50 changes: 49 additions & 1 deletion test/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,59 @@ using StaticArrays, Test, LinearAlgebra
# So try all of these
@testset "Mixed static/dynamic" begin
v = @SVector([0.2,0.3])
# Square matrices
for m in (@SMatrix([1.0 0; 0 1.0]), @SMatrix([1.0 0; 1.0 1.0]),
@SMatrix([1.0 1.0; 0 1.0]), @SMatrix([1.0 0.5; 0.25 1.0]))
# TODO: include @SMatrix([1.0 0.0 0.0; 1.0 2.0 0.5]), need qr methods
@test m \ v ≈ Array(m) \ v ≈ m \ Array(v) ≈ Array(m) \ Array(v)
end
# Rectangular matrices
for m in (@SMatrix([1.0 0.0 0.0; 1.0 2.0 0.5]), @SMatrix([1.0 2.0 0.5; 0.0 0.0 1.0]),
@SMatrix([0.0 0.0 1.0; 1.0 2.0 0.5]), @SMatrix([1.0 2.0 0.5; 1.0 0.0 0.0]))
@test m \ v ≈ Array(m) \ v ≈ Array(m) \ Array(v)
@test_throws MethodError m \ Array(v) # TODO: requires adjoint(::QR) method
end
end
@testset "More static tests" begin
# 1) 3×5 real, two RHS
A1 = @SMatrix [1.0 2.0 3.0 4.0 5.0;
0.0 1.0 0.0 1.0 0.0;
-1.0 0.0 2.0 -2.0 1.0]
B1 = @SMatrix [1.0 0.0;
0.0 1.0;
1.0 1.0]

# 2) 4×6 real
A2 = @SMatrix [ 2.0 -1.0 0.0 4.0 1.0 3.0;
-3.0 2.0 5.0 -1.0 0.0 2.0;
1.0 0.0 1.0 0.0 2.0 -2.0;
0.0 3.0 -1.0 1.0 1.0 0.0]
b2_1 = @SVector [1.0, 4.0, -2.0, 0.5]
b2_2 = @SMatrix [1.0 1.0
4.0 6.0
-2.0 2.0
0.5 1.5]

# 3) 3×4 complex
A3 = @SMatrix [1+2im 0+1im 2-1im 3+0im;
0+0im 2+0im 1+1im 0-2im;
3-1im -1+0im 0+2im 1+0im]
b3_1 = @SVector [1+0im, 2-1im, -1+3im]
b3_2 = @SMatrix [
1+0im -9+0im
2-1im 2-4im
-1+3im 2+3im]

# 4) 3×6 rank-deficient (cols 3 = 1+2, col 4 = col 1, col 5 = col 2, col 6 = zeros)
A4 = @SMatrix [1.0 2.0 3.0 1.0 2.0 0.0;
0.0 1.0 1.0 0.0 1.0 0.0;
1.0 3.0 4.0 1.0 3.0 0.0]
b4_1 = @SVector [1.0, 0.0, 1.0]
b4_2 = @SMatrix [1.0 0.0
0.0 1.0
1.0 0.0]
for (A, B) in [(A1, B1), (A2, b2_1), (A2, b2_2), (A3, b3_1), (A3, b3_2), (A4, b4_1), (A4, b4_2)]
@test A \ B ≈ Array(A) \ Array(B)
end
end
end

Expand Down
Loading