Skip to content

Commit 87cd5cf

Browse files
committed
Better LU factorization and linear solve
1 parent 580aa6d commit 87cd5cf

File tree

2 files changed

+58
-49
lines changed

2 files changed

+58
-49
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ Base.convert(::Type{<:Array{Num}}, x::AbstractArray) = map(Num, x)
123123
Base.convert(::Type{<:Array{Num}}, x::AbstractArray{Num}) = x
124124
Base.convert(::Type{Sym}, x::Num) = value(x) isa Sym ? value(x) : error("cannot convert $x to Sym")
125125

126-
LinearAlgebra.lu(x::Array{Num}; kw...) = lu(x, Val{false}(); kw...)
126+
LinearAlgebra.lu(x::Array{Num}; kw...) = sym_lu(x)
127127

128128
"""
129129
$(TYPEDEF)

src/solve.jl

Lines changed: 57 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,31 @@ end
1212
# It should work as-is with Operation type too.
1313
function sym_lu(A)
1414
m, n = size(A)
15-
L = fill!(Array{Any}(undef, size(A)),0) # TODO: make sparse?
16-
for i=1:min(m, n)
17-
L[i,i] = 1
18-
end
19-
U = copy!(Array{Any}(undef, size(A)),A)
20-
p = BlasInt[1:m;]
21-
for k = 1:m-1
22-
_, i = findmin(map(ii->_iszero(U[ii, k]) ? Inf : nterms(U[ii,k]), k:n))
15+
F = map(x->x isa Num ? x : Num(x), A)
16+
minmn = min(m, n)
17+
p = Vector{BlasInt}(undef, minmn)
18+
info = zero(BlasInt)
19+
for k = 1:minmn
20+
val, i = findmin(map(ii->_iszero(F[ii, k]) ? Inf : nterms(F[ii,k]), k:n))
21+
if !(val isa Symbolic) && (val isa Number) && val == Inf && iszero(info)
22+
info = k
23+
end
2324
i += k - 1
2425
# swap
25-
U[k, k:end], U[i, k:end] = U[i, k:end], U[k, k:end]
26-
L[k, 1:k-1], L[i, 1:k-1] = L[i, 1:k-1], L[k, 1:k-1]
26+
for j in 1:n
27+
F[k, j], F[i, j] = F[i, j], F[k, j]
28+
end
2729
p[k] = i
28-
29-
for j = k+1:m
30-
L[j,k] = U[j, k] / U[k, k]
31-
U[j,k:m] .= U[j,k:m] .- (L[j,k],) .* U[k,k:m]
30+
for i in k+1:m
31+
F[i, k] = F[i, k] / F[k, k]
3232
end
33-
end
34-
for j=1:m
35-
for i=j+1:n
36-
U[i,j] = 0
33+
for j = k+1:n
34+
for i in k+1:m
35+
F[i, j] = F[i, j] - F[i, k] * F[k, j]
36+
end
3737
end
3838
end
39-
40-
(L, U, LinearAlgebra.ipiv2perm(p, m))
39+
LU(F, p, info)
4140
end
4241

4342
# Given a vector of equations and a
@@ -77,46 +76,56 @@ function solve_for(eqs, vars)
7776
end
7877

7978
function _solve(A::AbstractMatrix, b::AbstractArray)
80-
A = SymbolicUtils.simplify.(to_symbolic.(A), polynorm=true)
81-
b = SymbolicUtils.simplify.(to_symbolic.(b), polynorm=true)
82-
SymbolicUtils.simplify.(ldiv(sym_lu(A), b))
79+
A = SymbolicUtils.simplify.(Num.(A), polynorm=true)
80+
b = SymbolicUtils.simplify.(Num.(b), polynorm=true)
81+
value.(SymbolicUtils.simplify.(sym_lu(A) \ b))
8382
end
8483
_solve(a, b) = value(SymbolicUtils.simplify(b/a, polynorm=true))
8584

8685
# ldiv below
8786

8887
_iszero(x::Number) = iszero(x)
8988
_isone(x::Number) = isone(x)
90-
_iszero(::Term) = false
91-
_isone(::Term) = false
92-
93-
function simplifying_dot(x,y)
94-
isempty(x) && return 0
95-
muls = map(x,y) do xi,yi
96-
_isone(xi) ? yi : _isone(yi) ? xi : _iszero(xi) ? 0 : _iszero(yi) ? 0 : xi * yi
89+
_iszero(::Symbolic) = false
90+
_isone(::Symbolic) = false
91+
_iszero(x::Num) = _iszero(value(x))
92+
_isone(x::Num) = _isone(value(x))
93+
94+
LinearAlgebra.ldiv!(A::UpperTriangular{<:Union{Symbolic,Num}}, b::AbstractVector{<:Union{Symbolic,Num}}, x::AbstractVector{<:Union{Symbolic,Num}} = b) = symsub!(A, b, x)
95+
function symsub!(A::UpperTriangular, b::AbstractVector, x::AbstractVector = b)
96+
LinearAlgebra.require_one_based_indexing(A, b, x)
97+
n = size(A, 2)
98+
if !(n == length(b) == length(x))
99+
throw(DimensionMismatch("second dimension of left hand side A, $n, length of output x, $(length(x)), and length of right hand side b, $(length(b)), must be equal"))
97100
end
98-
99-
reduce(muls) do acc, x
100-
_iszero(acc) ? x : _iszero(x) ? acc : acc + x
101+
@inbounds for j in n:-1:1
102+
_iszero(A.data[j,j]) && throw(SingularException(j))
103+
xj = x[j] = b[j] / A.data[j,j]
104+
for i in j-1:-1:1
105+
sub = _isone(xj) ? A.data[i,j] : A.data[i,j] * xj
106+
if !_iszero(sub)
107+
b[i] -= sub
108+
end
109+
end
101110
end
111+
x
102112
end
103113

104-
function ldiv((L,U,p), b)
105-
m, n = size(L)
106-
x = Vector{Any}(undef, length(b))
107-
b = b[p]
108-
109-
for i=n:-1:1
110-
sub = simplifying_dot(x[i+1:end], U[i,i+1:end])
111-
den = U[i,i]
112-
x[i] = _iszero(sub) ? b[i] : b[i] - sub
113-
x[i] = _isone(den) ? x[i] : _isone(-den) ? -x[i] : x[i] / den
114+
LinearAlgebra.ldiv!(A::UnitLowerTriangular{<:Union{Symbolic,Num}}, b::AbstractVector{<:Union{Symbolic,Num}}, x::AbstractVector{<:Union{Symbolic,Num}} = b) = symsub!(A, b, x)
115+
function symsub!(A::UnitLowerTriangular, b::AbstractVector, x::AbstractVector = b)
116+
LinearAlgebra.require_one_based_indexing(A, b, x)
117+
n = size(A, 2)
118+
if !(n == length(b) == length(x))
119+
throw(DimensionMismatch("second dimension of left hand side A, $n, length of output x, $(length(x)), and length of right hand side b, $(length(b)), must be equal"))
114120
end
115-
116-
# unit lower triangular solve first:
117-
for i=1:n
118-
sub = simplifying_dot(b[1:i-1], L[i, 1:i-1]) # this should be `b` not x
119-
x[i] = _iszero(sub) ? x[i] : x[i] - sub
121+
@inbounds for j in 1:n
122+
xj = x[j] = b[j]
123+
for i in j+1:n
124+
sub = _isone(xj) ? A.data[i,j] : A.data[i,j] * xj
125+
if !_iszero(sub)
126+
b[i] -= sub
127+
end
128+
end
120129
end
121130
x
122131
end

0 commit comments

Comments
 (0)