Skip to content

Commit 68cc4bb

Browse files
committed
LU-based equation solve
1 parent 963ff11 commit 68cc4bb

File tree

4 files changed

+104
-37
lines changed

4 files changed

+104
-37
lines changed

src/ModelingToolkit.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ using Base: RefValue
1717
using RecursiveArrayTools
1818

1919
import SymbolicUtils
20-
import SymbolicUtils: to_symbolic, FnType, @rule, Rewriters
20+
import SymbolicUtils: to_symbolic, FnType, @rule, Rewriters, Term
21+
22+
using LinearAlgebra: LU, BlasInt
2123

2224
import LightGraphs: SimpleDiGraph, add_edge!
2325

@@ -90,6 +92,7 @@ include("function_registration.jl")
9092
include("simplify.jl")
9193
include("utils.jl")
9294
include("linearity.jl")
95+
include("solve.jl")
9396
include("direct.jl")
9497
include("domains.jl")
9598

src/simplify.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ end
1515
SymbolicUtils.promote_symtype(f::SymbolicUtils.Sym{FnType{X,Parameter{Y}}},
1616
xs...) where {X, Y} = Y
1717

18-
SymbolicUtils.arguments(x::Operation) = x.args
18+
SymbolicUtils.arguments(x::Operation) = to_symbolic.(x.args)
1919

2020
# SymbolicUtils wants raw numbers
2121
SymbolicUtils.to_symbolic(x::Constant) = x.value

src/solve.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Soft pivoted
2+
# Note: we call this function with a matrix of Union{SymbolicUtils.Symbolic, Any}
3+
# It should work as-is with Operation type too.
4+
function sym_lu(A)
5+
m, n = size(A)
6+
L = Array{Any}(undef, size(A)) # TODO: make sparse?
7+
for i=1:min(m, n)
8+
L[i,i] = 1
9+
end
10+
U = copy(A)
11+
p = BlasInt[1:m;]
12+
for k = 1:m-1
13+
i = k
14+
for ii=k:n
15+
if !iszero(U[ii, k])
16+
i = ii
17+
break
18+
end
19+
end
20+
# swap
21+
U[k, k:end], U[i, k:end] = U[i, k:end], U[k, k:end]
22+
L[k, 1:k-1], L[i, 1:k-1] = L[i, 1:k-1], L[k, 1:k-1]
23+
p[k] = i
24+
25+
for j = k+1:m
26+
L[j,k] = U[j, k] / U[k, k]
27+
U[j,k:m] .= U[j,k:m] .- simplifying_dot(L[j,k], U[k,k:m])
28+
end
29+
end
30+
factors = copy(U)
31+
for j=1:m
32+
for i=j+1:n
33+
factors[i,j] = L[i,j]
34+
end
35+
end
36+
37+
LU(factors, p, BlasInt(0))
38+
end
39+
40+
# Given a vector of equations and a
41+
# list of independent variables,
42+
# return the coefficient matrix `A` and a
43+
# vector of constants (possibly symbolic) `b` such that
44+
# A \ b will solve the equations for the vars
45+
function A_b(eqs, vars)
46+
exprs = rhss(eqs) .- lhss(eqs)
47+
for ex in exprs
48+
@assert islinear(ex, vars)
49+
end
50+
A = jacobian(exprs, vars)
51+
b = A * vars - exprs
52+
A, b
53+
end
54+
55+
function solve_for(eqs, vars)
56+
A, b = A_b(eqs, vars)
57+
A = SymbolicUtils.simplify.(to_symbolic.(A), polynorm=true)
58+
b = SymbolicUtils.simplify.(to_symbolic.(b), polynorm=true)
59+
map(to_mtk, SymbolicUtils.simplify.(ldiv(sym_lu(A), b)))
60+
end
61+
62+
# ldiv below
63+
64+
_iszero(x::Number) = iszero(x)
65+
_isone(x::Number) = isone(x)
66+
_iszero(::Term) = false
67+
_isone(::Term) = false
68+
69+
function simplifying_dot(x,y)
70+
isempty(x) && return 0
71+
muls = map(x,y) do xi,yi
72+
_isone(xi) ? yi : _isone(yi) ? xi : _iszero(xi) ? 0 : _iszero(yi) ? 0 : xi * yi
73+
end
74+
75+
reduce(muls) do acc, x
76+
_iszero(acc) ? x : _iszero(x) ? acc : acc + x
77+
end
78+
end
79+
80+
function ldiv(A::LU, b)
81+
# unit lower triangular solve first:
82+
L = A.L
83+
U = A.U
84+
m, n = size(L)
85+
x = Vector{Any}(undef, length(b))
86+
b = b[A.p]
87+
for i=1:n
88+
sub = simplifying_dot(b[1:i-1], L[i, 1:i-1])
89+
x[i] = _iszero(sub) ? b[i] : b[i] - sub
90+
end
91+
92+
for i=n:-1:1
93+
sub = simplifying_dot(b[i+1:end], U[i,i+1:end])
94+
den = U[i,i]
95+
x[i] = _iszero(sub) ? x[i] : x[i] - sub
96+
x[i] = _isone(U[i,i]) ? x[i] : x[i] / U[i,i]
97+
end
98+
x
99+
end

src/utils.jl

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -135,38 +135,3 @@ function substituter(pairs)
135135
end
136136

137137
@deprecate substitute_expr!(expr,s) substitute(expr,s)
138-
139-
# Really bad solve for vars
140-
function solve_for(eqs, vars)
141-
@assert length(eqs) >= length(vars)
142-
@assert all(iszero(eq.lhs) for eq in eqs)
143-
neweqs = []
144-
for (i, eq) in enumerate(eqs)
145-
rhs = eq.rhs
146-
if rhs.op == (-)
147-
if any(isequal(rhs.args[1]), vars) && any(isequal(rhs.args[2]), vars)
148-
push!(neweqs, rhs.args[1] ~ rhs.args[2]) # pick one?
149-
@warn("todo")
150-
elseif any(isequal(rhs.args[1]), vars)
151-
push!(neweqs, rhs.args[1] ~ rhs.args[2])
152-
elseif any(isequal(rhs.args[2]), vars)
153-
push!(neweqs, rhs.args[2] ~ rhs.args[1])
154-
else
155-
@warn("may require unimplemented solve")
156-
#error("todo 2")
157-
push!(neweqs, eq)
158-
end
159-
elseif rhs.op == (+)
160-
eqs[i] = 0 ~ rhs.args[1] - (-rhs.args[2])
161-
else
162-
error("todo")
163-
end
164-
end
165-
if length(neweqs) >= length(vars)
166-
return neweqs
167-
else
168-
# substitute
169-
eqs′ = Equation.(0, substitute.(rhss(eqs), (Dict(lhss(neweqs) .=> rhss(neweqs)),)))
170-
solve_for(eqs′, vars)
171-
end
172-
end

0 commit comments

Comments
 (0)