From 953e4e3310b66bf574061876d3de335498ac870d Mon Sep 17 00:00:00 2001 From: cgarling Date: Mon, 11 Aug 2025 17:59:18 -0400 Subject: [PATCH 1/9] Add solve methods for rectangular `A` --- src/StaticArrays.jl | 2 +- src/solve.jl | 30 +++++++++++++++++++++++------- test/solve.jl | 8 +++++++- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 023c45f57..16fb4ae03 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -120,6 +120,7 @@ include("matrix_multiply.jl") include("lu.jl") include("det.jl") include("inv.jl") +include("qr.jl") include("solve.jl") include("eigen.jl") include("expm.jl") @@ -128,7 +129,6 @@ include("lyap.jl") include("triangular.jl") include("cholesky.jl") include("svd.jl") -include("qr.jl") include("deque.jl") include("flatten.jl") include("io.jl") diff --git a/src/solve.jl b/src/solve.jl index fd9b351b8..2bddcb074 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,4 +1,5 @@ @inline (\)(a::StaticMatrix, b::StaticVecOrMat) = _solve(Size(a), Size(b), a, b) +@inline (\)(Q::QR, b::StaticVecOrMat) = Q.R \ (Q.Q' * b) @inline function _solve(::Size{(1,1)}, ::Size{(1,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb} @inbounds return similar_type(b, typeof(a[1] \ b[1]))(a[1] \ b[1]) @@ -55,15 +56,30 @@ end throw(DimensionMismatch("Left and right hand side first dimensions do not match in backdivide (got sizes $Sa and $Sb)")) end end - if prod(Sa) ≤ 14*14 && Sa[1] == Sa[2] + if prod(Sa) ≤ 14*14 # TODO: Consider triangular special cases as in Base? - quote - @_inline_meta - LUp = lu(a) - LUp.U \ (LUp.L \ $(length(Sb) > 1 ? :(b[LUp.p,:]) : :(b[LUp.p]))) + if Sa[1] == Sa[2] + quote + @_inline_meta + LUp = lu(a) + LUp.U \ (LUp.L \ $(length(Sb) > 1 ? :(b[LUp.p,:]) : :(b[LUp.p]))) + end + else + + quote + @_inline_meta + q = qr(a) + y = q.Q' * b + if Sa[1] > Sa[2] + R₁ = SMatrix{Sa[2], Sa[2]}(q.R[SOneTo(Sa[2]), SOneTo(Sa[2])]) + # return inv(R₁) * y + return R₁ \ y + else + return q.R' * ((q.R * q.R') \ y) + # return pinv(q.R) * y + end + end end - # TODO: Could also use static QR here if `a` is nonsquare. - # Requires that we implement \(::StaticArrays.QR,::StaticVecOrMat) else # Fall back to LinearAlgebra, but carry across the statically known size. outsize = length(Sb) == 1 ? Size(Sa[2]) : Size(Sa[2],Sb[end]) diff --git a/test/solve.jl b/test/solve.jl index 71e102257..45b92b337 100644 --- a/test/solve.jl +++ b/test/solve.jl @@ -21,11 +21,17 @@ using StaticArrays, Test, LinearAlgebra # So try all of these @testset "Mixed static/dynamic" begin v = @SVector([0.2,0.3]) + # Square matrices for m in (@SMatrix([1.0 0; 0 1.0]), @SMatrix([1.0 0; 1.0 1.0]), @SMatrix([1.0 1.0; 0 1.0]), @SMatrix([1.0 0.5; 0.25 1.0])) - # TODO: include @SMatrix([1.0 0.0 0.0; 1.0 2.0 0.5]), need qr methods @test m \ v ≈ Array(m) \ v ≈ m \ Array(v) ≈ Array(m) \ Array(v) end + # Rectangular matrices + for m in (@SMatrix([1.0 0.0 0.0; 1.0 2.0 0.5]), @SMatrix([1.0 2.0 0.5; 0.0 0.0 1.0]), + @SMatrix([0.0 0.0 1.0; 1.0 2.0 0.5]), @SMatrix([1.0 2.0 0.5; 1.0 0.0 0.0])) + @test m \ v ≈ Array(m) \ v ≈ Array(m) \ Array(v) + @test_throws MethodError m \ Array(v) # TODO: requires adjoint(::QR) method + end end end From c348fa5c24482321a2a67816c9583308ab0bcba5 Mon Sep 17 00:00:00 2001 From: cgarling Date: Mon, 11 Aug 2025 18:00:13 -0400 Subject: [PATCH 2/9] Remove alternative implementations --- src/solve.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 2bddcb074..302b09c2b 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -72,11 +72,9 @@ end y = q.Q' * b if Sa[1] > Sa[2] R₁ = SMatrix{Sa[2], Sa[2]}(q.R[SOneTo(Sa[2]), SOneTo(Sa[2])]) - # return inv(R₁) * y - return R₁ \ y + R₁ \ y else - return q.R' * ((q.R * q.R') \ y) - # return pinv(q.R) * y + q.R' * ((q.R * q.R') \ y) end end end From 2b35006c8ee527a89592d27c58f915517a24ac3a Mon Sep 17 00:00:00 2001 From: cgarling Date: Mon, 11 Aug 2025 18:20:03 -0400 Subject: [PATCH 3/9] remove stray newline --- src/solve.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/solve.jl b/src/solve.jl index 302b09c2b..9d54a90ff 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -65,7 +65,6 @@ end LUp.U \ (LUp.L \ $(length(Sb) > 1 ? :(b[LUp.p,:]) : :(b[LUp.p]))) end else - quote @_inline_meta q = qr(a) From 1379f1c4f3d78db511dac1dd19bcf72e68f4c0be Mon Sep 17 00:00:00 2001 From: cgarling Date: Tue, 12 Aug 2025 09:51:45 -0400 Subject: [PATCH 4/9] Move `::QR` solving logic out of `_solve_general` --- src/solve.jl | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 9d54a90ff..11749d8e5 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,5 +1,19 @@ @inline (\)(a::StaticMatrix, b::StaticVecOrMat) = _solve(Size(a), Size(b), a, b) -@inline (\)(Q::QR, b::StaticVecOrMat) = Q.R \ (Q.Q' * b) +@inline (\)(q::QR, b::StaticVecOrMat) = _solve(Size(q.Q), Size(q.R), q, b) +@inline function _solve(::Size{Sq}, ::Size{Sr}, q::QR, b::StaticVecOrMat) where {Sq, Sr} + Sa = (Sq[1], Sr[2]) # Size of the original matrix: Q * R + Q, R = q.Q, q.R + if Sa[1] == Sa[2] + return R \ (Q' * b) + elseif Sa[1] > Sa[2] + y = Q' * b + R₁ = @view R[SOneTo(Sa[2]), SOneTo(Sa[2])] + return R₁ \ y + else + y = Q' * b + return R' * ((R * R') \ y) + end +end @inline function _solve(::Size{(1,1)}, ::Size{(1,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb} @inbounds return similar_type(b, typeof(a[1] \ b[1]))(a[1] \ b[1]) @@ -68,13 +82,7 @@ end quote @_inline_meta q = qr(a) - y = q.Q' * b - if Sa[1] > Sa[2] - R₁ = SMatrix{Sa[2], Sa[2]}(q.R[SOneTo(Sa[2]), SOneTo(Sa[2])]) - R₁ \ y - else - q.R' * ((q.R * q.R') \ y) - end + q \ b end end else From 71b4ab3bb63afba3b7346a57d93736d9a790d7e4 Mon Sep 17 00:00:00 2001 From: cgarling Date: Tue, 12 Aug 2025 10:06:50 -0400 Subject: [PATCH 5/9] Use `UpperTriangular` when possible --- src/solve.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 11749d8e5..dd6e7c1a4 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -4,10 +4,10 @@ Sa = (Sq[1], Sr[2]) # Size of the original matrix: Q * R Q, R = q.Q, q.R if Sa[1] == Sa[2] - return R \ (Q' * b) + return UpperTriangular(R) \ (Q' * b) elseif Sa[1] > Sa[2] y = Q' * b - R₁ = @view R[SOneTo(Sa[2]), SOneTo(Sa[2])] + R₁ = UpperTriangular(@view R[SOneTo(Sa[2]), SOneTo(Sa[2])]) return R₁ \ y else y = Q' * b From fc315c3a0c0b779dc4d4306439931bfdceda3411 Mon Sep 17 00:00:00 2001 From: cgarling Date: Tue, 12 Aug 2025 10:11:01 -0400 Subject: [PATCH 6/9] hoist `y` out of branches --- src/solve.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index dd6e7c1a4..24995c395 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -3,14 +3,13 @@ @inline function _solve(::Size{Sq}, ::Size{Sr}, q::QR, b::StaticVecOrMat) where {Sq, Sr} Sa = (Sq[1], Sr[2]) # Size of the original matrix: Q * R Q, R = q.Q, q.R + y = Q' * b if Sa[1] == Sa[2] - return UpperTriangular(R) \ (Q' * b) + return UpperTriangular(R) \ y elseif Sa[1] > Sa[2] - y = Q' * b R₁ = UpperTriangular(@view R[SOneTo(Sa[2]), SOneTo(Sa[2])]) return R₁ \ y else - y = Q' * b return R' * ((R * R') \ y) end end From f6fdce6ec6162dabd7a3a90d5542d07f0bd3ed62 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Wed, 13 Aug 2025 10:55:13 +0200 Subject: [PATCH 7/9] add ported code for wide matrix solve with qr; add more tests; use pivoted QR by default in solve --- src/solve.jl | 60 +++++++++++++++++++++++++++++++++++++++++++++++++-- test/solve.jl | 42 ++++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 2 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 24995c395..be716c77d 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -10,10 +10,66 @@ R₁ = UpperTriangular(@view R[SOneTo(Sa[2]), SOneTo(Sa[2])]) return R₁ \ y else - return R' * ((R * R') \ y) + return _wide_qr_solve(q, b) end end +# based on https://github.com/JuliaLang/LinearAlgebra.jl/blob/16f64e78769d788376df0f36447affdb7b1b3df6/src/qr.jl#L652C1-L697C4 +function _wide_qr_solve(A::QR{T}, B::StaticMatrix{mB,nB,T}) where {mB,nB,T} + m, n = size(A) + minmn = min(m, n) + Bbuffer = similar(B) + copyto!(Bbuffer, B) + lmul!(adjoint(A.Q), view(Bbuffer, 1:m, :)) + Rbuffer = similar(A.R) + copyto!(Rbuffer, A.R) + + @inbounds begin + if n > m # minimum norm solution + τ = zeros(T,m) + for k = m:-1:1 # Trapezoid to triangular by elementary operation + x = view(Rbuffer, k, [k; m + 1:n]) + τk = LinearAlgebra.reflector!(x) + τ[k] = conj(τk) + for i = 1:k - 1 + vRi = Rbuffer[i,k] + for j = m + 1:n + vRi += Rbuffer[i,j]*x[j - m + 1]' + end + vRi *= τk + Rbuffer[i,k] -= vRi + for j = m + 1:n + Rbuffer[i,j] -= vRi*x[j - m + 1] + end + end + end + end + ldiv!(UpperTriangular(view(Rbuffer, :, SOneTo(minmn))), view(Bbuffer, SOneTo(minmn), :)) + if n > m # Apply elementary transformation to solution + Bbuffer[m + 1:mB,1:nB] .= zero(T) + for j = 1:nB + for k = 1:m + vBj = Bbuffer[k,j]' + for i = m + 1:n + vBj += Bbuffer[i,j]'*Rbuffer[k,i]' + end + vBj *= τ[k] + Bbuffer[k,j] -= vBj' + for i = m + 1:n + Bbuffer[i,j] -= Rbuffer[k,i]'*vBj' + end + end + end + end + end + return similar_type(B)(Bbuffer) +end +function _wide_qr_solve(q::QR, b::StaticVecOrMat) + Q, R = q.Q, q.R + y = Q' * b + return R' * ((R * R') \ y) +end + @inline function _solve(::Size{(1,1)}, ::Size{(1,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb} @inbounds return similar_type(b, typeof(a[1] \ b[1]))(a[1] \ b[1]) end @@ -80,7 +136,7 @@ end else quote @_inline_meta - q = qr(a) + q = qr(a, ColumnNorm()) q \ b end end diff --git a/test/solve.jl b/test/solve.jl index 45b92b337..91a93c020 100644 --- a/test/solve.jl +++ b/test/solve.jl @@ -33,6 +33,48 @@ using StaticArrays, Test, LinearAlgebra @test_throws MethodError m \ Array(v) # TODO: requires adjoint(::QR) method end end + @testset "More static tests" begin + # 1) 3×5 real, two RHS + A1 = @SMatrix [1.0 2.0 3.0 4.0 5.0; + 0.0 1.0 0.0 1.0 0.0; + -1.0 0.0 2.0 -2.0 1.0] + B1 = @SMatrix [1.0 0.0; + 0.0 1.0; + 1.0 1.0] + + # 2) 4×6 real + A2 = @SMatrix [ 2.0 -1.0 0.0 4.0 1.0 3.0; + -3.0 2.0 5.0 -1.0 0.0 2.0; + 1.0 0.0 1.0 0.0 2.0 -2.0; + 0.0 3.0 -1.0 1.0 1.0 0.0] + b2_1 = @SVector [1.0, 4.0, -2.0, 0.5] + b2_2 = @SMatrix [1.0 1.0 + 4.0 6.0 + -2.0 2.0 + 0.5 1.5] + + # 3) 3×4 complex + A3 = @SMatrix [1+2im 0+1im 2-1im 3+0im; + 0+0im 2+0im 1+1im 0-2im; + 3-1im -1+0im 0+2im 1+0im] + b3_1 = @SVector [1+0im, 2-1im, -1+3im] + b3_2 = @SMatrix [ + 1+0im -9+0im + 2-1im 2-4im + -1+3im 2+3im] + + # 4) 3×6 rank-deficient (cols 3 = 1+2, col 4 = col 1, col 5 = col 2, col 6 = zeros) + A4 = @SMatrix [1.0 2.0 3.0 1.0 2.0 0.0; + 0.0 1.0 1.0 0.0 1.0 0.0; + 1.0 3.0 4.0 1.0 3.0 0.0] + b4_1 = @SVector [1.0, 0.0, 1.0] + b4_2 = @SMatrix [1.0 0.0 + 0.0 1.0 + 1.0 0.0] + for (A, B) in [(A1, B1), (A2, b2_1), (A2, b2_2), (A3, b3_1), (A3, b3_2), (A4, b4_1), (A4, b4_2)] + @test A \ B ≈ Array(A) \ Array(B) + end + end end From 8f47cd0c08d1c45c6e1a92982a0894faadd69425 Mon Sep 17 00:00:00 2001 From: cgarling Date: Wed, 13 Aug 2025 10:02:04 -0400 Subject: [PATCH 8/9] Apply permutation for pivoted QR --- src/solve.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index be716c77d..690be95f2 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -2,16 +2,18 @@ @inline (\)(q::QR, b::StaticVecOrMat) = _solve(Size(q.Q), Size(q.R), q, b) @inline function _solve(::Size{Sq}, ::Size{Sr}, q::QR, b::StaticVecOrMat) where {Sq, Sr} Sa = (Sq[1], Sr[2]) # Size of the original matrix: Q * R - Q, R = q.Q, q.R + Q, R, p = q.Q, q.R, q.p y = Q' * b - if Sa[1] == Sa[2] - return UpperTriangular(R) \ y + Z = if Sa[1] == Sa[2] + UpperTriangular(R) \ y elseif Sa[1] > Sa[2] R₁ = UpperTriangular(@view R[SOneTo(Sa[2]), SOneTo(Sa[2])]) - return R₁ \ y + R₁ \ y else - return _wide_qr_solve(q, b) + _wide_qr_solve(q, b) end + invp = invperm(p) + return @view Z[invp, :] end # based on https://github.com/JuliaLang/LinearAlgebra.jl/blob/16f64e78769d788376df0f36447affdb7b1b3df6/src/qr.jl#L652C1-L697C4 From aa648ce9730a9c37bd17fea481e70c69e97585cb Mon Sep 17 00:00:00 2001 From: cgarling Date: Wed, 13 Aug 2025 10:57:01 -0400 Subject: [PATCH 9/9] Only apply permutation if necessary --- src/solve.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index 690be95f2..ba1f0a660 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -12,8 +12,13 @@ else _wide_qr_solve(q, b) end - invp = invperm(p) - return @view Z[invp, :] + # Apply pivot permutation if necessary + return if p != SOneTo(length(p)) + invp = invperm(p) + @view Z[invp, :] + else + Z + end end # based on https://github.com/JuliaLang/LinearAlgebra.jl/blob/16f64e78769d788376df0f36447affdb7b1b3df6/src/qr.jl#L652C1-L697C4