Skip to content

Implement solve (\) for general rectangular static matrices #1313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ include("matrix_multiply.jl")
include("lu.jl")
include("det.jl")
include("inv.jl")
include("qr.jl")
include("solve.jl")
include("eigen.jl")
include("expm.jl")
Expand All @@ -128,7 +129,6 @@ include("lyap.jl")
include("triangular.jl")
include("cholesky.jl")
include("svd.jl")
include("qr.jl")
include("deque.jl")
include("flatten.jl")
include("io.jl")
Expand Down
27 changes: 20 additions & 7 deletions src/solve.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
@inline (\)(a::StaticMatrix, b::StaticVecOrMat) = _solve(Size(a), Size(b), a, b)
@inline (\)(Q::QR, b::StaticVecOrMat) = Q.R \ (Q.Q' * b)

@inline function _solve(::Size{(1,1)}, ::Size{(1,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb}
@inbounds return similar_type(b, typeof(a[1] \ b[1]))(a[1] \ b[1])
Expand Down Expand Up @@ -55,15 +56,27 @@ end
throw(DimensionMismatch("Left and right hand side first dimensions do not match in backdivide (got sizes $Sa and $Sb)"))
end
end
if prod(Sa) ≤ 14*14 && Sa[1] == Sa[2]
if prod(Sa) ≤ 14*14
# TODO: Consider triangular special cases as in Base?
quote
@_inline_meta
LUp = lu(a)
LUp.U \ (LUp.L \ $(length(Sb) > 1 ? :(b[LUp.p,:]) : :(b[LUp.p])))
if Sa[1] == Sa[2]
quote
@_inline_meta
LUp = lu(a)
LUp.U \ (LUp.L \ $(length(Sb) > 1 ? :(b[LUp.p,:]) : :(b[LUp.p])))
end
else
quote
@_inline_meta
q = qr(a)
y = q.Q' * b
if Sa[1] > Sa[2]
R₁ = SMatrix{Sa[2], Sa[2]}(q.R[SOneTo(Sa[2]), SOneTo(Sa[2])])
R₁ \ y
else
q.R' * ((q.R * q.R') \ y)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
q.R' * ((q.R * q.R') \ y)
q.R' * (q.R' \ (q.R \ y))

should be a lot faster, if q.R is annotated as UpperTriangular (you could wrap it if not).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review, couple questions:

q.R has dimensions (Sa[1], Sa[2]) with Sa[1] < Sa[2]. I don't see how to use UpperTriangular on a non-square matrix as, for example, UpperTriangular(rand(2,3)) gives ERROR: DimensionMismatch: matrix is not square: dimensions are (2, 3). Am I missing something?

Using the expression q.R' * (q.R' \ (q.R \ y)) you provided results in a stack overflow error, I believe because of the (q.R \ y) which will recurse infinitely because q.R has the same dimensions as argument a and y has the same dimensions as argument b. Any suggestions to improve this are welcome

end
end
end
# TODO: Could also use static QR here if `a` is nonsquare.
# Requires that we implement \(::StaticArrays.QR,::StaticVecOrMat)
else
# Fall back to LinearAlgebra, but carry across the statically known size.
outsize = length(Sb) == 1 ? Size(Sa[2]) : Size(Sa[2],Sb[end])
Expand Down
8 changes: 7 additions & 1 deletion test/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,17 @@ using StaticArrays, Test, LinearAlgebra
# So try all of these
@testset "Mixed static/dynamic" begin
v = @SVector([0.2,0.3])
# Square matrices
for m in (@SMatrix([1.0 0; 0 1.0]), @SMatrix([1.0 0; 1.0 1.0]),
@SMatrix([1.0 1.0; 0 1.0]), @SMatrix([1.0 0.5; 0.25 1.0]))
# TODO: include @SMatrix([1.0 0.0 0.0; 1.0 2.0 0.5]), need qr methods
@test m \ v ≈ Array(m) \ v ≈ m \ Array(v) ≈ Array(m) \ Array(v)
end
# Rectangular matrices
for m in (@SMatrix([1.0 0.0 0.0; 1.0 2.0 0.5]), @SMatrix([1.0 2.0 0.5; 0.0 0.0 1.0]),
@SMatrix([0.0 0.0 1.0; 1.0 2.0 0.5]), @SMatrix([1.0 2.0 0.5; 1.0 0.0 0.0]))
@test m \ v ≈ Array(m) \ v ≈ Array(m) \ Array(v)
@test_throws MethodError m \ Array(v) # TODO: requires adjoint(::QR) method
end
end
end

Expand Down
Loading