Skip to content

Commit 62ecf95

Browse files
committed
fixes and first few tests
1 parent 85fb196 commit 62ecf95

File tree

2 files changed

+57
-37
lines changed

2 files changed

+57
-37
lines changed

src/systems/reduction.jl

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ function substitute_aliases(diffeqs, dict)
2727
lhss(diffeqs) .~ fixpoint_sub.(rhss(diffeqs), (dict,))
2828
end
2929

30-
isvar(s::Sym) = !isparameter(s)
31-
isvar(s::Term) = isvar(s.op)
32-
isvar(s::Any) = false
30+
isvar(s::Sym; param=false) = param ? true : !isparameter(s)
31+
isvar(s::Term; param=false) = isvar(s.op; param=param)
32+
isvar(s::Any;param=false) = false
3333

3434
function filterexpr(f, s)
3535
vs = []
@@ -84,15 +84,15 @@ function alias_elimination(sys::ODESystem)
8484
end
8585

8686
function get_α_x(αx)
87-
if isvar(αx)
88-
return αx, 1
87+
if isvar(αx, param=true)
88+
return 1, αx
8989
elseif αx isa Term && operation(αx) === (*)
9090
args = arguments(αx)
9191
nums = filter(!isvar, args)
9292
syms = filter(isvar, args)
9393

9494
if length(syms) == 1
95-
return syms[1], prod(nums)
95+
return prod(nums), syms[1]
9696
end
9797
else
9898
return nothing
@@ -105,51 +105,68 @@ function alias_elimination2(sys)
105105
subs = Pair[]
106106
# Case 1: Right hand side is a constant
107107
ii = findall(eqs) do eq
108-
(eq.lhs isa Sym || (eq.lhs isa Term && !(eq.lhs.op isa Differential))) && !(eq.rhs isa Symbolic)
108+
!(eq.rhs isa Symbolic)
109109
end
110110
for eq in eqs[ii]
111-
substitution_dict[eq.lhs] = eq.rhs
112-
push!(subs, eq.lhs => eq.rhs)
111+
α,x = get_α_x(eq.lhs)
112+
push!(subs, x => isone(α) ? eq.rhs : eq.rhs / α)
113113
end
114114
deleteat!(eqs, ii) # remove them
115115

116116
# Case 2: One side is a differentiated var, the other is an algebraic var
117117
# substitute the algebraic var with the diff var
118-
diff_vars = findall(eqs) do eq
119-
if eq.lhs isa Term && eq.lhs.op isa Differential
120-
eq.lhs.args[1]
121-
else
122-
nothing
123-
end
124-
end
118+
diff_vars = filter(!isnothing, map(eqs) do eq
119+
if eq.lhs isa Term && eq.lhs.op isa Differential
120+
eq.lhs.args[1]
121+
else
122+
nothing
123+
end
124+
end) |> Set
125125

126-
for eq in eqs
126+
del = Int[]
127+
for (i, eq) in enumerate(eqs)
127128
res_left = get_α_x(eq.lhs)
128-
if !isnothing(res)
129+
if !isnothing(res_left)
130+
α, x = res_left
129131
res_right = get_α_x(eq.rhs)
130-
β, y = res
131-
if y in diff_vars && !(x in diff_vars)
132-
multiple = β / α
133-
push!(subs, x => isone(multiple) ? y : multiple * y)
134-
elseif x in diff_vars && !(y in diff_vars)
135-
multiple = α / β
136-
push!(subs, y => isone(multiple) ? y : multiple * y)
132+
if !isnothing(res_right)
133+
β, y = res_right
134+
if y in diff_vars && !(x in diff_vars)
135+
multiple = β / α
136+
push!(subs, x => isone(multiple) ? y : multiple * y)
137+
push!(del, i)
138+
elseif x in diff_vars && !(y in diff_vars)
139+
multiple = α / β
140+
push!(subs, y => isone(multiple) ? x : multiple * x)
141+
push!(del, i)
142+
end
137143
end
138144
end
139145
end
146+
deleteat!(eqs, del)
140147

141148
# Case 3: Explicit substitutions
142-
for eq in eqs
149+
del = Int[]
150+
for (i, eq) in enumerate(eqs)
143151
res_left = get_α_x(eq.lhs)
144-
if !isnothing(res)
152+
if !isnothing(res_left)
153+
α, x = res_left
145154
res_right = get_α_x(eq.rhs)
146-
β, y = res
147-
multiple = β / α
148-
push!(subs, x => isone(multiple) ? x : multiple * x)
155+
if !isnothing(res_right)
156+
β, y = res_right
157+
multiple = β / α
158+
push!(subs, x => _isone(multiple) ? x : multiple * x)
159+
push!(del, i)
160+
end
149161
end
150162
end
163+
deleteat!(eqs, del)
151164

152165
diffeqs = filter(eq -> eq.lhs isa Term && eq.lhs.op isa Differential, eqs)
153166
diffeqs′ = substitute_aliases(diffeqs, Dict(subs))
167+
168+
newstates = map(diffeqs) do eq
169+
eq.lhs.args[1]
170+
end
154171
ODESystem(diffeqs′, sys.iv, newstates, parameters(sys), observed=first.(subs) .~ last.(subs))
155172
end

test/reduction.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ModelingToolkit, OrdinaryDiffEq, Test
2+
using ModelingToolkit: alias_elimination2
23

34
@parameters t σ ρ β
45
@variables x(t) y(t) z(t) a(t) u(t) F(t)
@@ -9,19 +10,21 @@ test_equal(a, b) = @test isequal(simplify(a, polynorm=true), simplify(b, polynor
910
eqs = [D(x) ~ σ*(y-x),
1011
D(y) ~ x*-z)-y,
1112
D(z) ~ a*y - β*z,
12-
0 ~ x - a]
13+
β ~ 2,
14+
x ~ a]
1315

1416
lorenz1 = ODESystem(eqs,t,[x,y,z,a],[σ,ρ,β],name=:lorenz1)
1517

16-
lorenz1_aliased = alias_elimination(lorenz1)
18+
lorenz1_aliased = alias_elimination2(lorenz1)
1719
@test length(equations(lorenz1_aliased)) == 3
1820
@test length(states(lorenz1_aliased)) == 3
1921

2022
eqs = [D(x) ~ σ*(y-x),
2123
D(y) ~ x*-z)-y,
22-
D(z) ~ x*y - β*z]
24+
D(z) ~ x*y - 2*z]
2325

24-
@test lorenz1_aliased == ODESystem(eqs,t,[x,y,z],[σ,ρ,β],observed=[a ~ x],name=:lorenz1)
26+
# TODO: maybe remove β from ps, or maybe don't allow this example on params
27+
@test lorenz1_aliased == ODESystem(eqs,t,[x,y,z],[σ,ρ,β],observed=~ 2, a ~ x],name=:lorenz1)
2528

2629
# Multi-System Reduction
2730

@@ -44,7 +47,7 @@ lorenz2 = ODESystem(eqs2,pins=[F],observed=aliases2,name=:lorenz2)
4447
connections = [lorenz1.F ~ lorenz2.u,
4548
lorenz2.F ~ lorenz1.u]
4649

47-
connected = ODESystem([0 ~ a + lorenz1.x - lorenz2.y],t,[a],[],observed=connections,systems=[lorenz1,lorenz2])
50+
connected = ODESystem([lorenz2.y ~ a + lorenz1.x ],t,[a],[],observed=connections,systems=[lorenz1,lorenz2])
4851

4952
# Reduced Unflattened System
5053
#=
@@ -59,7 +62,7 @@ connected = ODESystem(Equation[],t,[],[],observed=connections2,systems=[lorenz1,
5962

6063
flattened_system = ModelingToolkit.flatten(connected)
6164

62-
aliased_flattened_system = alias_elimination(flattened_system)
65+
aliased_flattened_system = alias_elimination2(flattened_system)
6366

6467
@test isequal(states(aliased_flattened_system), [
6568
lorenz1.x
@@ -107,7 +110,7 @@ let
107110
x ~ y
108111
];
109112
sys = ODESystem(eqs, t, [x], []);
110-
asys = alias_elimination(ModelingToolkit.flatten(sys))
113+
asys = alias_elimination2(ModelingToolkit.flatten(sys))
111114

112115
test_equal.(asys.eqs, [D(x) ~ 2x])
113116
test_equal.(asys.observed, [y ~ x])

0 commit comments

Comments
 (0)