Skip to content

Commit efda924

Browse files
Merge pull request #683 from SciML/myb/linsol
Better LU factorization and linear solve
2 parents ce3c3e4 + dce4f9e commit efda924

File tree

3 files changed

+87
-56
lines changed

3 files changed

+87
-56
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ Base.convert(::Type{<:Array{Num}}, x::AbstractArray) = map(Num, x)
123123
Base.convert(::Type{<:Array{Num}}, x::AbstractArray{Num}) = x
124124
Base.convert(::Type{Sym}, x::Num) = value(x) isa Sym ? value(x) : error("cannot convert $x to Sym")
125125

126-
LinearAlgebra.lu(x::Array{Num}; kw...) = lu(x, Val{false}(); kw...)
126+
LinearAlgebra.lu(x::Array{Num}; check=true, kw...) = sym_lu(x; check=check)
127127

128128
"""
129129
$(TYPEDEF)

src/solve.jl

Lines changed: 72 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -9,35 +9,46 @@ function nterms(t)
99
end
1010
# Soft pivoted
1111
# Note: we call this function with a matrix of Union{SymbolicUtils.Symbolic, Any}
12-
# It should work as-is with Operation type too.
13-
function sym_lu(A)
12+
function sym_lu(A; check=true)
13+
SINGULAR = typemax(Int)
1414
m, n = size(A)
15-
L = fill!(Array{Any}(undef, size(A)),0) # TODO: make sparse?
16-
for i=1:min(m, n)
17-
L[i,i] = 1
18-
end
19-
U = copy!(Array{Any}(undef, size(A)),A)
20-
p = BlasInt[1:m;]
21-
for k = 1:m-1
22-
_, i = findmin(map(ii->_iszero(U[ii, k]) ? Inf : nterms(U[ii,k]), k:n))
23-
i += k - 1
15+
F = map(x->x isa Num ? x : Num(x), A)
16+
minmn = min(m, n)
17+
p = Vector{BlasInt}(undef, minmn)
18+
info = 0
19+
for k = 1:minmn
20+
kp = k
21+
amin = SINGULAR
22+
for i in k:m
23+
absi = _iszero(F[i, k]) ? SINGULAR : nterms(F[i,k])
24+
if absi < amin
25+
kp = i
26+
amin = absi
27+
end
28+
end
29+
30+
p[k] = kp
31+
32+
if amin == SINGULAR && !(amin isa Symbolic) && (amin isa Number) && iszero(info)
33+
info = k
34+
end
35+
2436
# swap
25-
U[k, k:end], U[i, k:end] = U[i, k:end], U[k, k:end]
26-
L[k, 1:k-1], L[i, 1:k-1] = L[i, 1:k-1], L[k, 1:k-1]
27-
p[k] = i
37+
for j in 1:n
38+
F[k, j], F[kp, j] = F[kp, j], F[k, j]
39+
end
2840

29-
for j = k+1:m
30-
L[j,k] = U[j, k] / U[k, k]
31-
U[j,k:m] .= U[j,k:m] .- (L[j,k],) .* U[k,k:m]
41+
for i in k+1:m
42+
F[i, k] = F[i, k] / F[k, k]
3243
end
33-
end
34-
for j=1:m
35-
for i=j+1:n
36-
U[i,j] = 0
44+
for j = k+1:n
45+
for i in k+1:m
46+
F[i, j] = F[i, j] - F[i, k] * F[k, j]
47+
end
3748
end
3849
end
39-
40-
(L, U, LinearAlgebra.ipiv2perm(p, m))
50+
check && LinearAlgebra.checknonsingular(info, Val{true}())
51+
LU(F, p, convert(BlasInt, info))
4152
end
4253

4354
# Given a vector of equations and a
@@ -77,46 +88,56 @@ function solve_for(eqs, vars)
7788
end
7889

7990
function _solve(A::AbstractMatrix, b::AbstractArray)
80-
A = SymbolicUtils.simplify.(to_symbolic.(A), polynorm=true)
81-
b = SymbolicUtils.simplify.(to_symbolic.(b), polynorm=true)
82-
SymbolicUtils.simplify.(ldiv(sym_lu(A), b))
91+
A = SymbolicUtils.simplify.(Num.(A), polynorm=true)
92+
b = SymbolicUtils.simplify.(Num.(b), polynorm=true)
93+
value.(SymbolicUtils.simplify.(sym_lu(A) \ b))
8394
end
8495
_solve(a, b) = value(SymbolicUtils.simplify(b/a, polynorm=true))
8596

8697
# ldiv below
8798

8899
_iszero(x::Number) = iszero(x)
89100
_isone(x::Number) = isone(x)
90-
_iszero(::Term) = false
91-
_isone(::Term) = false
92-
93-
function simplifying_dot(x,y)
94-
isempty(x) && return 0
95-
muls = map(x,y) do xi,yi
96-
_isone(xi) ? yi : _isone(yi) ? xi : _iszero(xi) ? 0 : _iszero(yi) ? 0 : xi * yi
101+
_iszero(::Symbolic) = false
102+
_isone(::Symbolic) = false
103+
_iszero(x::Num) = _iszero(value(x))
104+
_isone(x::Num) = _isone(value(x))
105+
106+
LinearAlgebra.ldiv!(A::UpperTriangular{<:Union{Symbolic,Num}}, b::AbstractVector{<:Union{Symbolic,Num}}, x::AbstractVector{<:Union{Symbolic,Num}} = b) = symsub!(A, b, x)
107+
function symsub!(A::UpperTriangular, b::AbstractVector, x::AbstractVector = b)
108+
LinearAlgebra.require_one_based_indexing(A, b, x)
109+
n = size(A, 2)
110+
if !(n == length(b) == length(x))
111+
throw(DimensionMismatch("second dimension of left hand side A, $n, length of output x, $(length(x)), and length of right hand side b, $(length(b)), must be equal"))
97112
end
98-
99-
reduce(muls) do acc, x
100-
_iszero(acc) ? x : _iszero(x) ? acc : acc + x
113+
@inbounds for j in n:-1:1
114+
_iszero(A.data[j,j]) && throw(SingularException(j))
115+
xj = x[j] = b[j] / A.data[j,j]
116+
for i in j-1:-1:1
117+
sub = _isone(xj) ? A.data[i,j] : A.data[i,j] * xj
118+
if !_iszero(sub)
119+
b[i] -= sub
120+
end
121+
end
101122
end
123+
x
102124
end
103125

104-
function ldiv((L,U,p), b)
105-
m, n = size(L)
106-
x = Vector{Any}(undef, length(b))
107-
b = b[p]
108-
109-
for i=n:-1:1
110-
sub = simplifying_dot(x[i+1:end], U[i,i+1:end])
111-
den = U[i,i]
112-
x[i] = _iszero(sub) ? b[i] : b[i] - sub
113-
x[i] = _isone(den) ? x[i] : _isone(-den) ? -x[i] : x[i] / den
126+
LinearAlgebra.ldiv!(A::UnitLowerTriangular{<:Union{Symbolic,Num}}, b::AbstractVector{<:Union{Symbolic,Num}}, x::AbstractVector{<:Union{Symbolic,Num}} = b) = symsub!(A, b, x)
127+
function symsub!(A::UnitLowerTriangular, b::AbstractVector, x::AbstractVector = b)
128+
LinearAlgebra.require_one_based_indexing(A, b, x)
129+
n = size(A, 2)
130+
if !(n == length(b) == length(x))
131+
throw(DimensionMismatch("second dimension of left hand side A, $n, length of output x, $(length(x)), and length of right hand side b, $(length(b)), must be equal"))
114132
end
115-
116-
# unit lower triangular solve first:
117-
for i=1:n
118-
sub = simplifying_dot(b[1:i-1], L[i, 1:i-1]) # this should be `b` not x
119-
x[i] = _iszero(sub) ? x[i] : x[i] - sub
133+
@inbounds for j in 1:n
134+
xj = x[j] = b[j]
135+
for i in j+1:n
136+
sub = _isone(xj) ? A.data[i,j] : A.data[i,j] * xj
137+
if !_iszero(sub)
138+
b[i] -= sub
139+
end
140+
end
120141
end
121142
x
122143
end

test/operation_overloads.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using LinearAlgebra
33
using SparseArrays: sparse
44
using Test
55

6-
@variables a,b,c,d
6+
@variables a,b,c,d,e,f,g,h,i
77

88
# test hashing
99
aa = a; # old a
@@ -18,12 +18,22 @@ aa = a; # old a
1818
@test hash(a+b ~ c+d) == hash(a+b ~ c+d)
1919

2020
# test some matrix operations don't throw errors
21-
X = [a b;c d]
22-
det(X)
23-
lu(X)
21+
X = [0 b c; d e f; g h i]
22+
@test iszero(simplify(det(X) - ((b * f * g) + (c * d * h) - (b * d * i) - (c * e * g)), polynorm=true))
23+
F = lu(X)
24+
@test F.p == [2, 1, 3]
25+
R = simplify.(F.L * F.U - X[F.p, :], polynorm=true)
26+
@test iszero(R)
27+
@test simplify.(F \ X) == I
28+
@test ModelingToolkit._solve(X, X) == I
2429
inv(X)
2530
qr(X)
2631

32+
X2 = [0 b c; 0 0 0; 0 h 0]
33+
@test_throws SingularException lu(X2)
34+
F2 = lu(X2, check=false)
35+
@test F2.info == 1
36+
2737
# test operations with sparse arrays and Operations
2838
# note `isequal` instead of `==` because `==` would give another Operation
2939

0 commit comments

Comments
 (0)