Skip to content

Commit 58631b7

Browse files
committed
fixpoint the substitution of variables
1 parent 791115e commit 58631b7

File tree

2 files changed

+30
-12
lines changed

2 files changed

+30
-12
lines changed

src/systems/reduction.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,24 @@ export alias_elimination
33
function flatten(sys::ODESystem)
44
ODESystem(equations(sys),
55
independent_variable(sys),
6-
states(sys),
7-
parameters(sys),
86
observed=observed(sys))
97
end
108

119

10+
using SymbolicUtils: Rewriters
11+
12+
function fixpoint_sub(x, dict)
13+
y = substitute(x, dict)
14+
while !isequal(x, y)
15+
y = x
16+
x = substitute(y, dict)
17+
end
18+
19+
return x
20+
end
21+
1222
function substitute_aliases(diffeqs, outputs)
13-
lhss(diffeqs) .~ substitute.(rhss(diffeqs), (Dict(lhss(outputs) .=> rhss(outputs)),))
23+
lhss(diffeqs) .~ fixpoint_sub.(rhss(diffeqs), (Dict(lhss(outputs) .=> rhss(outputs)),))
1424
end
1525

1626
function make_lhs_0(eq)
@@ -32,11 +42,17 @@ function alias_elimination(sys::ODESystem)
3242
[]
3343
end
3444
end |> Iterators.flatten |> collect |> unique
45+
46+
all_vars = map(eqs) do eq
47+
filter(x->!isparameter(x.op), get_variables(eq.rhs))
48+
end |> Iterators.flatten |> collect |> unique
49+
3550
newstates = convert.(Variable, new_stateops)
3651

3752
alg_idxs = findall(x->x.lhs isa Constant && iszero(x.lhs), eqs)
3853

39-
eliminate = setdiff(states(sys), newstates)
54+
eliminate = setdiff(convert.(Variable, all_vars), newstates)
55+
4056
outputs = solve_for(eqs[alg_idxs], map(x->x(sys.iv()), eliminate))
4157

4258
diffeqs = eqs[setdiff(1:length(eqs), alg_idxs)]

test/reduction.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
using ModelingToolkit, OrdinaryDiffEq, Test
22

3-
@parameters t σ ρ β F(t)
4-
@variables x(t) y(t) z(t) a(t) u(t)
3+
@parameters t σ ρ β
4+
@variables x(t) y(t) z(t) a(t) u(t) F(t)
55
@derivatives D'~t
66

7+
test_equal(a, b) = @test isequal(simplify(a, polynorm=true), simplify(b, polynorm=true))
8+
79
eqs = [D(x) ~ σ*(y-x),
810
D(y) ~ x*-z)-y,
911
D(z) ~ a*y - β*z,
@@ -78,16 +80,16 @@ aliased_flattened_system = alias_elimination(flattened_system)
7880
lorenz2.β
7981
])) |> isempty
8082

81-
#=
82-
equations(reduced_flattened_system) == [
83+
test_equal.(equations(aliased_flattened_system), [
8384
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x) + lorenz2.x - lorenz2.y - lorenz2.z,
84-
D(lorenz1.y) ~ lorenz1.x*(ρ-z)-lorenz1.x - lorenz1.y + lorenz1.z,
85-
D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z
85+
D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-(lorenz1.x + lorenz1.y - lorenz1.z),
86+
D(lorenz1.z) ~ lorenz1.x*lorenz1.y - lorenz1.β*lorenz1.z,
8687
D(lorenz2.x) ~ lorenz1.x + lorenz1.y - lorenz1.z,
8788
D(lorenz2.y) ~ lorenz2.x*(lorenz2.ρ-lorenz2.z)-lorenz2.x,
88-
D(lorenz2.z) ~ lorenz2.x*lorenz2.y - lorenz2.β*lorenz2.z]
89+
D(lorenz2.z) ~ lorenz2.x*lorenz2.y - lorenz2.β*lorenz2.z])
8990

90-
observed(reduced_flattened_system) == [
91+
#=
92+
observed(aliased_flattened_system) == [
9193
lorenz1.F ~ lorenz2.u
9294
lorenz2.F ~ lorenz1.u
9395
a ~ -lorenz1.x + lorenz2.y

0 commit comments

Comments
 (0)