Skip to content

Commit 2dc0221

Browse files
authored
Merge pull request #1718 from SciML/myb/alias
Make sure `D(-D(x))` gets expanded in alias elimination
2 parents 46778e3 + e4d10cc commit 2dc0221

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

src/systems/alias_elimination.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ end
3535
alias_elimination(sys) = alias_elimination!(TearingState(sys; quick_cancel = true))
3636
function alias_elimination!(state::TearingState)
3737
sys = state.sys
38+
complete!(state.structure)
3839
ag, mm, updated_diff_vars = alias_eliminate_graph!(state)
3940
ag === nothing && return sys
4041

@@ -52,8 +53,20 @@ function alias_elimination!(state::TearingState)
5253
end
5354

5455
subs = Dict()
56+
# If we encounter y = -D(x), then we need to expand the derivative when
57+
# D(y) appears in the equation, so that D(-D(x)) becomes -D(D(x)).
58+
to_expand = Int[]
59+
diff_to_var = invview(var_to_diff)
5560
for (v, (coeff, alias)) in pairs(ag)
5661
subs[fullvars[v]] = iszero(coeff) ? 0 : coeff * fullvars[alias]
62+
if coeff == -1
63+
# if `alias` is like -D(x)
64+
diff_to_var[alias] === nothing && continue
65+
# if `v` is like y, and D(y) also exists
66+
(dv = var_to_diff[v]) === nothing && continue
67+
# all equations that contains D(y) needs to be expanded.
68+
append!(to_expand, 𝑑neighbors(graph, dv))
69+
end
5770
end
5871

5972
dels = Int[]
@@ -72,11 +85,29 @@ function alias_elimination!(state::TearingState)
7285
end
7386
end
7487
deleteat!(eqs, sort!(dels))
88+
old_to_new = Vector{Int}(undef, length(var_to_diff))
89+
idx = 0
90+
cursor = 1
91+
ndels = length(dels)
92+
for (i, e) in enumerate(old_to_new)
93+
if cursor <= ndels && i == dels[cursor]
94+
cursor += 1
95+
old_to_new[i] = -1
96+
continue
97+
end
98+
idx += 1
99+
old_to_new[i] = idx
100+
end
75101

76102
for (ieq, eq) in enumerate(eqs)
77103
eqs[ieq] = substitute(eq, subs)
78104
end
79105

106+
for old_ieq in to_expand
107+
ieq = old_to_new[old_ieq]
108+
eqs[ieq] = expand_derivatives(eqs[ieq])
109+
end
110+
80111
newstates = []
81112
diff_to_var = invview(var_to_diff)
82113
for j in eachindex(fullvars)

test/odesystem.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,3 +849,17 @@ let
849849
D(sys.x) ~ sys.v]
850850
@test isequal(full_equations(sys_simp), true_eqs)
851851
end
852+
853+
let
854+
@variables t
855+
@variables x(t) = 1
856+
@variables y(t) = 1
857+
@parameters pp = -1
858+
der = Differential(t)
859+
@named sys4 = ODESystem([der(x) ~ -y; der(y) ~ 1 - y + x], t)
860+
as = alias_elimination(sys4)
861+
@test length(equations(as)) == 1
862+
@test isequal(equations(as)[1].lhs, -der(der(x)))
863+
# TODO: maybe do not emit x_t
864+
@test_nowarn sys4s = structural_simplify(sys4)
865+
end

0 commit comments

Comments
 (0)