Skip to content

Commit 9654bb2

Browse files
committed
Update alias_elimination tests
1 parent 6466d8a commit 9654bb2

File tree

2 files changed

+23
-73
lines changed

2 files changed

+23
-73
lines changed

src/systems/reduction.jl

Lines changed: 1 addition & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -33,58 +33,6 @@ isvar(s::Sym; param=false) = param ? true : !isparameter(s)
3333
isvar(s::Term; param=false) = isvar(s.op; param=param)
3434
isvar(s::Any;param=false) = false
3535

36-
function filterexpr(f, s)
37-
vs = []
38-
Rewriters.Prewalk(Rewriters.Chain([@rule((~x::f) => push!(vs, ~x))]))(s)
39-
vs
40-
end
41-
42-
function make_lhs_0(eq)
43-
if eq.lhs isa Number && iszero(eq.lhs)
44-
return eq
45-
else
46-
0 ~ eq.lhs - eq.rhs
47-
end
48-
end
49-
50-
function alias_elimination(sys::ODESystem)
51-
eqs = vcat(equations(sys), observed(sys))
52-
53-
# make all algebraic equations have 0 on LHS
54-
eqs = map(eqs) do eq
55-
if eq.lhs isa Term && eq.lhs.op isa Differential
56-
eq
57-
else
58-
make_lhs_0(eq)
59-
end
60-
end
61-
62-
newstates = map(eqs) do eq
63-
if eq.lhs isa Term && eq.lhs.op isa Differential
64-
filterexpr(isvar, eq.lhs)
65-
else
66-
[]
67-
end
68-
end |> Iterators.flatten |> collect |> unique
69-
70-
71-
all_vars = map(eqs) do eq
72-
filterexpr(isvar, eq.rhs)
73-
end |> Iterators.flatten |> collect |> unique
74-
75-
alg_idxs = findall(x->!(x.lhs isa Term) && iszero(x.lhs), eqs)
76-
77-
eliminate = setdiff(all_vars, newstates)
78-
79-
outputs = solve_for(eqs[alg_idxs], eliminate)
80-
81-
diffeqs = eqs[setdiff(1:length(eqs), alg_idxs)]
82-
83-
diffeqs′ = substitute_aliases(diffeqs, Dict(eliminate .=> outputs))
84-
85-
ODESystem(diffeqs′, sys.iv, newstates, parameters(sys), observed=eliminate .~ outputs)
86-
end
87-
8836
function get_α_x(αx)
8937
if isvar(αx, param=true)
9038
return 1, αx
@@ -101,7 +49,7 @@ function get_α_x(αx)
10149
end
10250
end
10351

104-
function alias_elimination2(sys)
52+
function alias_elimination(sys)
10553
eqs = vcat(equations(sys), observed(sys))
10654

10755
subs = Pair[]

test/reduction.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using ModelingToolkit, OrdinaryDiffEq, Test
2-
using ModelingToolkit: alias_elimination2
32

43
@parameters t σ ρ β
54
@variables x(t) y(t) z(t) a(t) u(t) F(t)
@@ -15,7 +14,7 @@ eqs = [D(x) ~ σ*(y-x),
1514

1615
lorenz1 = ODESystem(eqs,t,[x,y,z,a],[σ,ρ,β],name=:lorenz1)
1716

18-
lorenz1_aliased = alias_elimination2(lorenz1)
17+
lorenz1_aliased = alias_elimination(lorenz1)
1918
@test length(equations(lorenz1_aliased)) == 3
2019
@test length(states(lorenz1_aliased)) == 3
2120

@@ -47,7 +46,7 @@ lorenz2 = ODESystem(eqs2,pins=[F],observed=aliases2,name=:lorenz2)
4746
connections = [lorenz1.F ~ lorenz2.u,
4847
lorenz2.F ~ lorenz1.u]
4948

50-
connected = ODESystem([lorenz2.y ~ a + lorenz1.x ],t,[a],[],observed=connections,systems=[lorenz1,lorenz2])
49+
connected = ODESystem([lorenz2.y ~ a + lorenz1.x],t,[a],[],observed=connections,systems=[lorenz1,lorenz2])
5150

5251
# Reduced Unflattened System
5352
#=
@@ -62,7 +61,7 @@ connected = ODESystem(Equation[],t,[],[],observed=connections2,systems=[lorenz1,
6261

6362
flattened_system = ModelingToolkit.flatten(connected)
6463

65-
aliased_flattened_system = alias_elimination2(flattened_system)
64+
aliased_flattened_system = alias_elimination(flattened_system)
6665

6766
@test isequal(states(aliased_flattened_system), [
6867
lorenz1.x
@@ -83,21 +82,24 @@ aliased_flattened_system = alias_elimination2(flattened_system)
8382
lorenz2.β
8483
]) |> isempty
8584

86-
test_equal.(equations(aliased_flattened_system), [
87-
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z,
88-
D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z),
89-
D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z,
90-
D(lorenz2.x) ~ lorenz1.x + lorenz1.y - lorenz1.z,
91-
D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ-lorenz2.z)-lorenz2.x,
92-
D(lorenz2.z) ~ lorenz2.x*lorenz2.y - lorenz2.β*lorenz2.z])
93-
94-
test_equal.(observed(aliased_flattened_system), [
95-
lorenz1.F ~ lorenz2.x + -1 * (lorenz2.y + lorenz2.z),
96-
lorenz1.u ~ lorenz1.x + lorenz1.y + -1 * lorenz1.z,
97-
lorenz2.F ~ lorenz1.x + lorenz1.y + -1 * lorenz1.z,
98-
a ~ lorenz2.y + -1 * lorenz1.x,
99-
lorenz2.u ~ lorenz2.x + -1 * (lorenz2.y + lorenz2.z),
100-
])
85+
reduced_eqs = [
86+
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - (a + lorenz1.x) - lorenz2.z,
87+
D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z),
88+
D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z,
89+
D(lorenz2.x) ~ lorenz1.x + lorenz1.y - lorenz1.z,
90+
D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ-lorenz2.z)-lorenz2.x,
91+
D(lorenz2.z) ~ lorenz2.x*(a + lorenz1.x) - lorenz2.β*lorenz2.z
92+
]
93+
test_equal.(equations(aliased_flattened_system), reduced_eqs)
94+
95+
observed_eqs = [
96+
lorenz2.y ~ a + lorenz1.x,
97+
lorenz1.F ~ lorenz2.u,
98+
lorenz2.F ~ lorenz1.u,
99+
lorenz1.u ~ lorenz1.x + lorenz1.y - lorenz1.z,
100+
lorenz2.u ~ lorenz2.x - lorenz2.y - lorenz2.z,
101+
]
102+
test_equal.(observed(aliased_flattened_system), observed_eqs)
101103

102104

103105
# issue #578
@@ -110,7 +112,7 @@ let
110112
x ~ y
111113
];
112114
sys = ODESystem(eqs, t, [x], []);
113-
asys = alias_elimination2(ModelingToolkit.flatten(sys))
115+
asys = alias_elimination(ModelingToolkit.flatten(sys))
114116

115117
test_equal.(asys.eqs, [D(x) ~ 2x])
116118
test_equal.(asys.observed, [y ~ x])

0 commit comments

Comments
 (0)