Skip to content

Commit 6a9b682

Browse files
authored
Merge pull request #582 from SciML/s/fix578
fix #578
2 parents f1885a9 + 363a1a8 commit 6a9b682

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

src/systems/reduction.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
export alias_elimination
22

33
function flatten(sys::ODESystem)
4-
ODESystem(equations(sys),
5-
independent_variable(sys),
6-
observed=observed(sys))
4+
if isempty(sys.systems)
5+
return sys
6+
else
7+
return ODESystem(equations(sys),
8+
independent_variable(sys),
9+
observed=observed(sys))
10+
end
711
end
812

913

@@ -32,8 +36,16 @@ function make_lhs_0(eq)
3236
end
3337

3438
function alias_elimination(sys::ODESystem)
35-
eqs = vcat(equations(sys),
36-
make_lhs_0.(observed(sys)))
39+
eqs = vcat(equations(sys), observed(sys))
40+
41+
# make all algebraic equations have 0 on LHS
42+
eqs = map(eqs) do eq
43+
if eq.lhs isa Operation && eq.lhs.op isa Differential
44+
eq
45+
else
46+
make_lhs_0(eq)
47+
end
48+
end
3749

3850
new_stateops = map(eqs) do eq
3951
if eq.lhs isa Operation && eq.lhs.op isa Differential
@@ -49,11 +61,13 @@ function alias_elimination(sys::ODESystem)
4961

5062
newstates = convert.(Variable, new_stateops)
5163

64+
5265
alg_idxs = findall(x->x.lhs isa Constant && iszero(x.lhs), eqs)
5366

5467
eliminate = setdiff(convert.(Variable, all_vars), newstates)
5568

5669
vars = map(x->x(sys.iv()), eliminate)
70+
5771
outputs = solve_for(eqs[alg_idxs], vars)
5872

5973
diffeqs = eqs[setdiff(1:length(eqs), alg_idxs)]

test/reduction.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,20 @@ test_equal.(observed(aliased_flattened_system), [
9595
a ~ lorenz2.y + -1 * lorenz1.x,
9696
lorenz2.u ~ lorenz2.x + -1 * (lorenz2.y + lorenz2.z),
9797
])
98+
99+
100+
# issue #578
101+
102+
let
103+
@variables t x(t) y(t) z(t);
104+
@derivatives D'~t;
105+
eqs = [
106+
D(x) ~ x + y
107+
x ~ y
108+
];
109+
sys = ODESystem(eqs, t, [x], []);
110+
asys = alias_elimination(ModelingToolkit.flatten(sys))
111+
112+
test_equal.(asys.eqs, [D(x) ~ 2x])
113+
test_equal.(asys.observed, [y ~ x])
114+
end

0 commit comments

Comments
 (0)