diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 023c45f5..16fb4ae0 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -120,6 +120,7 @@ include("matrix_multiply.jl") include("lu.jl") include("det.jl") include("inv.jl") +include("qr.jl") include("solve.jl") include("eigen.jl") include("expm.jl") @@ -128,7 +129,6 @@ include("lyap.jl") include("triangular.jl") include("cholesky.jl") include("svd.jl") -include("qr.jl") include("deque.jl") include("flatten.jl") include("io.jl") diff --git a/src/solve.jl b/src/solve.jl index fd9b351b..ba1f0a66 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,4 +1,81 @@ @inline (\)(a::StaticMatrix, b::StaticVecOrMat) = _solve(Size(a), Size(b), a, b) +@inline (\)(q::QR, b::StaticVecOrMat) = _solve(Size(q.Q), Size(q.R), q, b) +@inline function _solve(::Size{Sq}, ::Size{Sr}, q::QR, b::StaticVecOrMat) where {Sq, Sr} + Sa = (Sq[1], Sr[2]) # Size of the original matrix: Q * R + Q, R, p = q.Q, q.R, q.p + y = Q' * b + Z = if Sa[1] == Sa[2] + UpperTriangular(R) \ y + elseif Sa[1] > Sa[2] + R₁ = UpperTriangular(@view R[SOneTo(Sa[2]), SOneTo(Sa[2])]) + R₁ \ y + else + _wide_qr_solve(q, b) + end + # Apply pivot permutation if necessary + return if p != SOneTo(length(p)) + invp = invperm(p) + @view Z[invp, :] + else + Z + end +end + +# based on https://github.com/JuliaLang/LinearAlgebra.jl/blob/16f64e78769d788376df0f36447affdb7b1b3df6/src/qr.jl#L652C1-L697C4 +function _wide_qr_solve(A::QR{T}, B::StaticMatrix{mB,nB,T}) where {mB,nB,T} + m, n = size(A) + minmn = min(m, n) + Bbuffer = similar(B) + copyto!(Bbuffer, B) + lmul!(adjoint(A.Q), view(Bbuffer, 1:m, :)) + Rbuffer = similar(A.R) + copyto!(Rbuffer, A.R) + + @inbounds begin + if n > m # minimum norm solution + τ = zeros(T,m) + for k = m:-1:1 # Trapezoid to triangular by elementary operation + x = view(Rbuffer, k, [k; m + 1:n]) + τk = LinearAlgebra.reflector!(x) + τ[k] = conj(τk) + for i = 1:k - 1 + vRi = Rbuffer[i,k] + for j = m + 1:n + vRi += Rbuffer[i,j]*x[j - m + 1]' + end + vRi *= τk + Rbuffer[i,k] -= vRi + for j = m + 1:n + Rbuffer[i,j] -= vRi*x[j - m + 1] + end + end + end + end + ldiv!(UpperTriangular(view(Rbuffer, :, SOneTo(minmn))), view(Bbuffer, SOneTo(minmn), :)) + if n > m # Apply elementary transformation to solution + Bbuffer[m + 1:mB,1:nB] .= zero(T) + for j = 1:nB + for k = 1:m + vBj = Bbuffer[k,j]' + for i = m + 1:n + vBj += Bbuffer[i,j]'*Rbuffer[k,i]' + end + vBj *= τ[k] + Bbuffer[k,j] -= vBj' + for i = m + 1:n + Bbuffer[i,j] -= Rbuffer[k,i]'*vBj' + end + end + end + end + end + return similar_type(B)(Bbuffer) +end +function _wide_qr_solve(q::QR, b::StaticVecOrMat) + Q, R = q.Q, q.R + y = Q' * b + return R' * ((R * R') \ y) +end @inline function _solve(::Size{(1,1)}, ::Size{(1,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb} @inbounds return similar_type(b, typeof(a[1] \ b[1]))(a[1] \ b[1]) @@ -55,15 +132,21 @@ end throw(DimensionMismatch("Left and right hand side first dimensions do not match in backdivide (got sizes $Sa and $Sb)")) end end - if prod(Sa) ≤ 14*14 && Sa[1] == Sa[2] + if prod(Sa) ≤ 14*14 # TODO: Consider triangular special cases as in Base? - quote - @_inline_meta - LUp = lu(a) - LUp.U \ (LUp.L \ $(length(Sb) > 1 ? :(b[LUp.p,:]) : :(b[LUp.p]))) + if Sa[1] == Sa[2] + quote + @_inline_meta + LUp = lu(a) + LUp.U \ (LUp.L \ $(length(Sb) > 1 ? :(b[LUp.p,:]) : :(b[LUp.p]))) + end + else + quote + @_inline_meta + q = qr(a, ColumnNorm()) + q \ b + end end - # TODO: Could also use static QR here if `a` is nonsquare. - # Requires that we implement \(::StaticArrays.QR,::StaticVecOrMat) else # Fall back to LinearAlgebra, but carry across the statically known size. outsize = length(Sb) == 1 ? Size(Sa[2]) : Size(Sa[2],Sb[end]) diff --git a/test/solve.jl b/test/solve.jl index 71e10225..91a93c02 100644 --- a/test/solve.jl +++ b/test/solve.jl @@ -21,11 +21,59 @@ using StaticArrays, Test, LinearAlgebra # So try all of these @testset "Mixed static/dynamic" begin v = @SVector([0.2,0.3]) + # Square matrices for m in (@SMatrix([1.0 0; 0 1.0]), @SMatrix([1.0 0; 1.0 1.0]), @SMatrix([1.0 1.0; 0 1.0]), @SMatrix([1.0 0.5; 0.25 1.0])) - # TODO: include @SMatrix([1.0 0.0 0.0; 1.0 2.0 0.5]), need qr methods @test m \ v ≈ Array(m) \ v ≈ m \ Array(v) ≈ Array(m) \ Array(v) end + # Rectangular matrices + for m in (@SMatrix([1.0 0.0 0.0; 1.0 2.0 0.5]), @SMatrix([1.0 2.0 0.5; 0.0 0.0 1.0]), + @SMatrix([0.0 0.0 1.0; 1.0 2.0 0.5]), @SMatrix([1.0 2.0 0.5; 1.0 0.0 0.0])) + @test m \ v ≈ Array(m) \ v ≈ Array(m) \ Array(v) + @test_throws MethodError m \ Array(v) # TODO: requires adjoint(::QR) method + end + end + @testset "More static tests" begin + # 1) 3×5 real, two RHS + A1 = @SMatrix [1.0 2.0 3.0 4.0 5.0; + 0.0 1.0 0.0 1.0 0.0; + -1.0 0.0 2.0 -2.0 1.0] + B1 = @SMatrix [1.0 0.0; + 0.0 1.0; + 1.0 1.0] + + # 2) 4×6 real + A2 = @SMatrix [ 2.0 -1.0 0.0 4.0 1.0 3.0; + -3.0 2.0 5.0 -1.0 0.0 2.0; + 1.0 0.0 1.0 0.0 2.0 -2.0; + 0.0 3.0 -1.0 1.0 1.0 0.0] + b2_1 = @SVector [1.0, 4.0, -2.0, 0.5] + b2_2 = @SMatrix [1.0 1.0 + 4.0 6.0 + -2.0 2.0 + 0.5 1.5] + + # 3) 3×4 complex + A3 = @SMatrix [1+2im 0+1im 2-1im 3+0im; + 0+0im 2+0im 1+1im 0-2im; + 3-1im -1+0im 0+2im 1+0im] + b3_1 = @SVector [1+0im, 2-1im, -1+3im] + b3_2 = @SMatrix [ + 1+0im -9+0im + 2-1im 2-4im + -1+3im 2+3im] + + # 4) 3×6 rank-deficient (cols 3 = 1+2, col 4 = col 1, col 5 = col 2, col 6 = zeros) + A4 = @SMatrix [1.0 2.0 3.0 1.0 2.0 0.0; + 0.0 1.0 1.0 0.0 1.0 0.0; + 1.0 3.0 4.0 1.0 3.0 0.0] + b4_1 = @SVector [1.0, 0.0, 1.0] + b4_2 = @SMatrix [1.0 0.0 + 0.0 1.0 + 1.0 0.0] + for (A, B) in [(A1, B1), (A2, b2_1), (A2, b2_2), (A3, b3_1), (A3, b3_2), (A4, b4_1), (A4, b4_2)] + @test A \ B ≈ Array(A) \ Array(B) + end end end