Skip to content

Commit dba32a9

Browse files
Merge pull request #546 from SciML/reduction
automatic equation reduction
2 parents 97c95f3 + e65eb21 commit dba32a9

File tree

9 files changed

+308
-13
lines changed

9 files changed

+308
-13
lines changed

src/ModelingToolkit.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ using Base: RefValue
1818
using RecursiveArrayTools
1919

2020
import SymbolicUtils
21-
import SymbolicUtils: to_symbolic, FnType, @rule, Rewriters
21+
import SymbolicUtils: to_symbolic, FnType, @rule, Rewriters, Term
22+
23+
using LinearAlgebra: LU, BlasInt
2224

2325
import LightGraphs: SimpleDiGraph, add_edge!
2426

@@ -91,6 +93,7 @@ include("function_registration.jl")
9193
include("simplify.jl")
9294
include("utils.jl")
9395
include("linearity.jl")
96+
include("solve.jl")
9497
include("direct.jl")
9598
include("domains.jl")
9699

@@ -116,6 +119,8 @@ include("systems/pde/pdesystem.jl")
116119
include("systems/reaction/reactionsystem.jl")
117120
include("systems/dependency_graphs.jl")
118121

122+
include("systems/reduction.jl")
123+
119124
include("latexify_recipes.jl")
120125
include("build_function.jl")
121126

src/simplify.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ end
1515
SymbolicUtils.promote_symtype(f::SymbolicUtils.Sym{FnType{X,Parameter{Y}}},
1616
xs...) where {X, Y} = Y
1717

18-
SymbolicUtils.arguments(x::Operation) = x.args
18+
SymbolicUtils.arguments(x::Operation) = to_symbolic.(x.args)
1919

2020
# SymbolicUtils wants raw numbers
2121
SymbolicUtils.to_symbolic(x::Constant) = x.value

src/solve.jl

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
using SymbolicUtils: istree
2+
3+
function nterms(t)
4+
if istree(t)
5+
return reduce(+, map(nterms, arguments(t)), init=0)
6+
else
7+
return 1
8+
end
9+
end
10+
# Soft pivoted
11+
# Note: we call this function with a matrix of Union{SymbolicUtils.Symbolic, Any}
12+
# It should work as-is with Operation type too.
13+
function sym_lu(A)
14+
m, n = size(A)
15+
L = fill!(Array{Any}(undef, size(A)),0) # TODO: make sparse?
16+
for i=1:min(m, n)
17+
L[i,i] = 1
18+
end
19+
U = copy!(Array{Any}(undef, size(A)),A)
20+
p = BlasInt[1:m;]
21+
for k = 1:m-1
22+
_, i = findmin(map(ii->iszero(U[ii, k]) ? Inf : nterms(U[ii,k]), k:n))
23+
i += k - 1
24+
# swap
25+
U[k, k:end], U[i, k:end] = U[i, k:end], U[k, k:end]
26+
L[k, 1:k-1], L[i, 1:k-1] = L[i, 1:k-1], L[k, 1:k-1]
27+
p[k] = i
28+
29+
for j = k+1:m
30+
L[j,k] = U[j, k] / U[k, k]
31+
U[j,k:m] .= U[j,k:m] .- L[j,k] .* U[k,k:m]
32+
end
33+
end
34+
for j=1:m
35+
for i=j+1:n
36+
U[i,j] = 0
37+
end
38+
end
39+
40+
(L, U, LinearAlgebra.ipiv2perm(p, m))
41+
end
42+
43+
# Given a vector of equations and a
44+
# list of independent variables,
45+
# return the coefficient matrix `A` and a
46+
# vector of constants (possibly symbolic) `b` such that
47+
# A \ b will solve the equations for the vars
48+
function A_b(eqs, vars)
49+
exprs = rhss(eqs) .- lhss(eqs)
50+
for ex in exprs
51+
@assert islinear(ex, vars)
52+
end
53+
A = jacobian(exprs, vars)
54+
b = A * vars - exprs
55+
A, b
56+
end
57+
58+
"""
59+
solve_for(eqs::Vector, vars::Vector)
60+
61+
Solve the vector of equations `eqs` for a set of variables `vars`.
62+
63+
Assumes `length(eqs) == length(vars)`
64+
65+
Currently only works if all equations are linear.
66+
"""
67+
function solve_for(eqs, vars)
68+
A, b = A_b(eqs, vars)
69+
_solve(A, b)
70+
end
71+
72+
function _solve(A, b)
73+
A = SymbolicUtils.simplify.(to_symbolic.(A), polynorm=true)
74+
b = SymbolicUtils.simplify.(to_symbolic.(b), polynorm=true)
75+
map(to_mtk, SymbolicUtils.simplify.(ldiv(sym_lu(A), b)))
76+
end
77+
78+
LinearAlgebra.:(\)(A::AbstractMatrix{<:Expression}, b::AbstractVector{<:Expression}) = _solve(A, b)
79+
LinearAlgebra.:(\)(A::AbstractMatrix{<:Expression}, b::AbstractVector) = _solve(A, b)
80+
LinearAlgebra.:(\)(A::AbstractMatrix, b::AbstractVector{<:Expression}) = _solve(A, b)
81+
82+
# ldiv below
83+
84+
_iszero(x::Number) = iszero(x)
85+
_isone(x::Number) = isone(x)
86+
_iszero(::Term) = false
87+
_isone(::Term) = false
88+
89+
function simplifying_dot(x,y)
90+
isempty(x) && return 0
91+
muls = map(x,y) do xi,yi
92+
_isone(xi) ? yi : _isone(yi) ? xi : _iszero(xi) ? 0 : _iszero(yi) ? 0 : xi * yi
93+
end
94+
95+
reduce(muls) do acc, x
96+
_iszero(acc) ? x : _iszero(x) ? acc : acc + x
97+
end
98+
end
99+
100+
function ldiv((L,U,p), b)
101+
m, n = size(L)
102+
x = Vector{Any}(undef, length(b))
103+
b = b[p]
104+
105+
for i=n:-1:1
106+
sub = simplifying_dot(x[i+1:end], U[i,i+1:end])
107+
den = U[i,i]
108+
x[i] = _iszero(sub) ? b[i] : b[i] - sub
109+
x[i] = _isone(den) ? x[i] : _isone(-den) ? -x[i] : x[i] / den
110+
end
111+
112+
# unit lower triangular solve first:
113+
for i=1:n
114+
sub = simplifying_dot(b[1:i-1], L[i, 1:i-1]) # this should be `b` not x
115+
x[i] = _iszero(sub) ? x[i] : x[i] - sub
116+
end
117+
x
118+
end

src/systems/abstractsystem.jl

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ pins(sys::AbstractSystem) = isempty(sys.systems) ? sys.pins : [sys.pins;reduce(v
193193
function observed(sys::AbstractSystem)
194194
[sys.observed;
195195
reduce(vcat,
196-
(namespace_equation.(s.observed, s.name, s.iv.name) for s in sys.systems),
196+
(namespace_equation.(observed(s), s.name, s.iv.name) for s in sys.systems),
197197
init=Equation[])]
198198
end
199199

@@ -215,7 +215,7 @@ end
215215
lhss(xs) = map(x->x.lhs, xs)
216216
rhss(xs) = map(x->x.rhs, xs)
217217

218-
function equations(sys::ModelingToolkit.AbstractSystem; remove_aliases = true)
218+
function equations(sys::ModelingToolkit.AbstractSystem)
219219
if isempty(sys.systems)
220220
return sys.eqs
221221
else
@@ -224,14 +224,7 @@ function equations(sys::ModelingToolkit.AbstractSystem; remove_aliases = true)
224224
namespace_equations.(sys.systems);
225225
init=Equation[])]
226226

227-
if !remove_aliases
228-
return eqs
229-
end
230-
aliases = observed(sys)
231-
dict = Dict(lhss(aliases) .=> rhss(aliases))
232-
233-
# Substitute aliases
234-
return Equation.(lhss(eqs), Rewriters.Fixpoint(x->substitute(x, dict)).(rhss(eqs)))
227+
return eqs
235228
end
236229
end
237230

src/systems/reduction.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
export alias_elimination
2+
3+
function flatten(sys::ODESystem)
4+
ODESystem(equations(sys),
5+
independent_variable(sys),
6+
observed=observed(sys))
7+
end
8+
9+
10+
using SymbolicUtils: Rewriters
11+
12+
function fixpoint_sub(x, dict)
13+
y = substitute(x, dict)
14+
while !isequal(x, y)
15+
y = x
16+
x = substitute(y, dict)
17+
end
18+
19+
return x
20+
end
21+
22+
function substitute_aliases(diffeqs, dict)
23+
lhss(diffeqs) .~ fixpoint_sub.(rhss(diffeqs), (dict,))
24+
end
25+
26+
function make_lhs_0(eq)
27+
if eq.lhs isa Constant && iszero(eq.lhs)
28+
return eq
29+
else
30+
0 ~ eq.lhs - eq.rhs
31+
end
32+
end
33+
34+
function alias_elimination(sys::ODESystem)
35+
eqs = vcat(equations(sys),
36+
make_lhs_0.(observed(sys)))
37+
38+
new_stateops = map(eqs) do eq
39+
if eq.lhs isa Operation && eq.lhs.op isa Differential
40+
get_variables(eq.lhs)
41+
else
42+
[]
43+
end
44+
end |> Iterators.flatten |> collect |> unique
45+
46+
all_vars = map(eqs) do eq
47+
filter(x->!isparameter(x.op), get_variables(eq.rhs))
48+
end |> Iterators.flatten |> collect |> unique
49+
50+
newstates = convert.(Variable, new_stateops)
51+
52+
alg_idxs = findall(x->x.lhs isa Constant && iszero(x.lhs), eqs)
53+
54+
eliminate = setdiff(convert.(Variable, all_vars), newstates)
55+
56+
vars = map(x->x(sys.iv()), eliminate)
57+
outputs = solve_for(eqs[alg_idxs], vars)
58+
59+
diffeqs = eqs[setdiff(1:length(eqs), alg_idxs)]
60+
61+
diffeqs′ = substitute_aliases(diffeqs, Dict(vars .=> outputs))
62+
63+
ODESystem(diffeqs′, sys.iv(), new_stateops, parameters(sys), observed=vars .~ outputs)
64+
end
65+

src/utils.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,18 @@ substitute(expr::Operation, s::Union{Vector, Dict}) = substituter(s)(expr)
131131

132132
function substituter(pairs)
133133
dict = Dict(to_symbolic(k) => to_symbolic(v) for (k, v) in pairs)
134-
expr -> to_mtk(SymbolicUtils.simplify(SymbolicUtils.substitute(expr, dict)))
134+
expr -> to_mtk(SymbolicUtils.substitute(expr, dict))
135+
end
136+
137+
macro showarr(x)
138+
n = string(x)
139+
quote
140+
y = $(esc(x))
141+
println($n, " = ", summary(y))
142+
Base.print_array(stdout, y)
143+
println()
144+
y
145+
end
135146
end
136147

137148
@deprecate substitute_expr!(expr,s) substitute(expr,s)

test/operation_overloads.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,8 @@ M \ b
7070
M \ reshape(b,2,1)
7171
M = [1 1; 0 2]
7272
M \ reshape(b,2,1)
73+
74+
75+
M = [1 a; 0 2]
76+
M \ b
77+
M \ [1, 2]

test/reduction.jl

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
using ModelingToolkit, OrdinaryDiffEq, Test
2+
3+
@parameters t σ ρ β
4+
@variables x(t) y(t) z(t) a(t) u(t) F(t)
5+
@derivatives D'~t
6+
7+
test_equal(a, b) = @test isequal(simplify(a, polynorm=true), simplify(b, polynorm=true))
8+
9+
eqs = [D(x) ~ σ*(y-x),
10+
D(y) ~ x*-z)-y,
11+
D(z) ~ a*y - β*z,
12+
0 ~ x - a]
13+
14+
lorenz1 = ODESystem(eqs,t,[x,y,z,a],[σ,ρ,β],name=:lorenz1)
15+
16+
lorenz1_aliased = alias_elimination(lorenz1)
17+
@test length(equations(lorenz1_aliased)) == 3
18+
@test length(states(lorenz1_aliased)) == 3
19+
20+
eqs = [D(x) ~ σ*(y-x),
21+
D(y) ~ x*-z)-y,
22+
D(z) ~ x*y - β*z]
23+
24+
@test lorenz1_aliased == ODESystem(eqs,t,[x,y,z],[σ,ρ,β],observed=[a ~ x],name=:lorenz1)
25+
26+
# Multi-System Reduction
27+
28+
eqs1 = [D(x) ~ σ*(y-x) + F,
29+
D(y) ~ x*-z)-u,
30+
D(z) ~ x*y - β*z]
31+
32+
aliases = [u ~ x + y - z]
33+
34+
lorenz1 = ODESystem(eqs1,pins=[F],observed=aliases,name=:lorenz1)
35+
36+
eqs2 = [D(x) ~ F,
37+
D(y) ~ x*-z)-x,
38+
D(z) ~ x*y - β*z]
39+
40+
aliases2 = [u ~ x - y - z]
41+
42+
lorenz2 = ODESystem(eqs2,pins=[F],observed=aliases2,name=:lorenz2)
43+
44+
connections = [lorenz1.F ~ lorenz2.u,
45+
lorenz2.F ~ lorenz1.u]
46+
47+
connected = ODESystem([0 ~ a + lorenz1.x - lorenz2.y],t,[a],[],observed=connections,systems=[lorenz1,lorenz2])
48+
49+
# Reduced Unflattened System
50+
#=
51+
52+
connections2 = [lorenz1.F ~ lorenz2.u,
53+
lorenz2.F ~ lorenz1.u,
54+
a ~ -lorenz1.x + lorenz2.y]
55+
connected = ODESystem(Equation[],t,[],[],observed=connections2,systems=[lorenz1,lorenz2])
56+
=#
57+
58+
# Reduced Flattened System
59+
60+
flattened_system = ModelingToolkit.flatten(connected)
61+
62+
aliased_flattened_system = alias_elimination(flattened_system)
63+
64+
@test states(aliased_flattened_system) == convert.(Variable, [
65+
lorenz1.x
66+
lorenz1.y
67+
lorenz1.z
68+
lorenz2.x
69+
lorenz2.y
70+
lorenz2.z
71+
])
72+
73+
@test setdiff(parameters(aliased_flattened_system), convert.(Variable, [
74+
lorenz1.σ
75+
lorenz1.ρ
76+
lorenz1.β
77+
lorenz1.F
78+
lorenz2.F
79+
lorenz2.ρ
80+
lorenz2.β
81+
])) |> isempty
82+
83+
test_equal.(equations(aliased_flattened_system), [
84+
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z,
85+
D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z),
86+
D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z,
87+
D(lorenz2.x) ~ lorenz1.x + lorenz1.y - lorenz1.z,
88+
D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ-lorenz2.z)-lorenz2.x,
89+
D(lorenz2.z) ~ lorenz2.x*lorenz2.y - lorenz2.β*lorenz2.z])
90+
91+
test_equal.(observed(aliased_flattened_system), [
92+
lorenz1.F ~ lorenz2.x + -1 * (lorenz2.y + lorenz2.z),
93+
lorenz1.u ~ lorenz1.x + lorenz1.y + -1 * lorenz1.z,
94+
lorenz2.F ~ lorenz1.x + lorenz1.y + -1 * lorenz1.z,
95+
a ~ lorenz2.y + -1 * lorenz1.x,
96+
lorenz2.u ~ lorenz2.x + -1 * (lorenz2.y + lorenz2.z),
97+
])

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ using SafeTestsets, Test
2020
@safetestset "Build Targets Test" begin include("build_targets.jl") end
2121
@safetestset "Domain Test" begin include("domains.jl") end
2222
@safetestset "Constraints Test" begin include("constraints.jl") end
23+
@safetestset "Reduction Test" begin include("reduction.jl") end
2324
@safetestset "PDE Construction Test" begin include("pde.jl") end
2425
@safetestset "Lowering Integration Test" begin include("lowering_solving.jl") end
2526
@safetestset "Test Big System Usage" begin include("bigsystem.jl") end

0 commit comments

Comments
 (0)