Skip to content

Commit f6fdce6

Browse files
committed
add ported code for wide matrix solve with qr; add more tests; use pivoted QR by default in solve
1 parent fc315c3 commit f6fdce6

File tree

2 files changed

+100
-2
lines changed

2 files changed

+100
-2
lines changed

src/solve.jl

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,66 @@
1010
R₁ = UpperTriangular(@view R[SOneTo(Sa[2]), SOneTo(Sa[2])])
1111
return R₁ \ y
1212
else
13-
return R' * ((R * R') \ y)
13+
return _wide_qr_solve(q, b)
1414
end
1515
end
1616

17+
# based on https://github.com/JuliaLang/LinearAlgebra.jl/blob/16f64e78769d788376df0f36447affdb7b1b3df6/src/qr.jl#L652C1-L697C4
18+
function _wide_qr_solve(A::QR{T}, B::StaticMatrix{mB,nB,T}) where {mB,nB,T}
19+
m, n = size(A)
20+
minmn = min(m, n)
21+
Bbuffer = similar(B)
22+
copyto!(Bbuffer, B)
23+
lmul!(adjoint(A.Q), view(Bbuffer, 1:m, :))
24+
Rbuffer = similar(A.R)
25+
copyto!(Rbuffer, A.R)
26+
27+
@inbounds begin
28+
if n > m # minimum norm solution
29+
τ = zeros(T,m)
30+
for k = m:-1:1 # Trapezoid to triangular by elementary operation
31+
x = view(Rbuffer, k, [k; m + 1:n])
32+
τk = LinearAlgebra.reflector!(x)
33+
τ[k] = conj(τk)
34+
for i = 1:k - 1
35+
vRi = Rbuffer[i,k]
36+
for j = m + 1:n
37+
vRi += Rbuffer[i,j]*x[j - m + 1]'
38+
end
39+
vRi *= τk
40+
Rbuffer[i,k] -= vRi
41+
for j = m + 1:n
42+
Rbuffer[i,j] -= vRi*x[j - m + 1]
43+
end
44+
end
45+
end
46+
end
47+
ldiv!(UpperTriangular(view(Rbuffer, :, SOneTo(minmn))), view(Bbuffer, SOneTo(minmn), :))
48+
if n > m # Apply elementary transformation to solution
49+
Bbuffer[m + 1:mB,1:nB] .= zero(T)
50+
for j = 1:nB
51+
for k = 1:m
52+
vBj = Bbuffer[k,j]'
53+
for i = m + 1:n
54+
vBj += Bbuffer[i,j]'*Rbuffer[k,i]'
55+
end
56+
vBj *= τ[k]
57+
Bbuffer[k,j] -= vBj'
58+
for i = m + 1:n
59+
Bbuffer[i,j] -= Rbuffer[k,i]'*vBj'
60+
end
61+
end
62+
end
63+
end
64+
end
65+
return similar_type(B)(Bbuffer)
66+
end
67+
function _wide_qr_solve(q::QR, b::StaticVecOrMat)
68+
Q, R = q.Q, q.R
69+
y = Q' * b
70+
return R' * ((R * R') \ y)
71+
end
72+
1773
@inline function _solve(::Size{(1,1)}, ::Size{(1,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb}
1874
@inbounds return similar_type(b, typeof(a[1] \ b[1]))(a[1] \ b[1])
1975
end
@@ -80,7 +136,7 @@ end
80136
else
81137
quote
82138
@_inline_meta
83-
q = qr(a)
139+
q = qr(a, ColumnNorm())
84140
q \ b
85141
end
86142
end

test/solve.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,48 @@ using StaticArrays, Test, LinearAlgebra
3333
@test_throws MethodError m \ Array(v) # TODO: requires adjoint(::QR) method
3434
end
3535
end
36+
@testset "More static tests" begin
37+
# 1) 3×5 real, two RHS
38+
A1 = @SMatrix [1.0 2.0 3.0 4.0 5.0;
39+
0.0 1.0 0.0 1.0 0.0;
40+
-1.0 0.0 2.0 -2.0 1.0]
41+
B1 = @SMatrix [1.0 0.0;
42+
0.0 1.0;
43+
1.0 1.0]
44+
45+
# 2) 4×6 real
46+
A2 = @SMatrix [ 2.0 -1.0 0.0 4.0 1.0 3.0;
47+
-3.0 2.0 5.0 -1.0 0.0 2.0;
48+
1.0 0.0 1.0 0.0 2.0 -2.0;
49+
0.0 3.0 -1.0 1.0 1.0 0.0]
50+
b2_1 = @SVector [1.0, 4.0, -2.0, 0.5]
51+
b2_2 = @SMatrix [1.0 1.0
52+
4.0 6.0
53+
-2.0 2.0
54+
0.5 1.5]
55+
56+
# 3) 3×4 complex
57+
A3 = @SMatrix [1+2im 0+1im 2-1im 3+0im;
58+
0+0im 2+0im 1+1im 0-2im;
59+
3-1im -1+0im 0+2im 1+0im]
60+
b3_1 = @SVector [1+0im, 2-1im, -1+3im]
61+
b3_2 = @SMatrix [
62+
1+0im -9+0im
63+
2-1im 2-4im
64+
-1+3im 2+3im]
65+
66+
# 4) 3×6 rank-deficient (cols 3 = 1+2, col 4 = col 1, col 5 = col 2, col 6 = zeros)
67+
A4 = @SMatrix [1.0 2.0 3.0 1.0 2.0 0.0;
68+
0.0 1.0 1.0 0.0 1.0 0.0;
69+
1.0 3.0 4.0 1.0 3.0 0.0]
70+
b4_1 = @SVector [1.0, 0.0, 1.0]
71+
b4_2 = @SMatrix [1.0 0.0
72+
0.0 1.0
73+
1.0 0.0]
74+
for (A, B) in [(A1, B1), (A2, b2_1), (A2, b2_2), (A3, b3_1), (A3, b3_2), (A4, b4_1), (A4, b4_2)]
75+
@test A \ B Array(A) \ Array(B)
76+
end
77+
end
3678
end
3779

3880

0 commit comments

Comments
 (0)