Skip to content

Commit a916c15

Browse files
committed
Don't symbolic substitute in tearing
1 parent 12b2efe commit a916c15

File tree

5 files changed

+98
-52
lines changed

5 files changed

+98
-52
lines changed

src/bipartite_graph.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,6 @@ function Graphs.incidence_matrix(g::BipartiteGraph, val=true)
397397
S = sparse(I, J, val, nsrcs(g), ndsts(g))
398398
end
399399

400-
401400
"""
402401
struct DiCMOBiGraph
403402

src/structural_transformation/symbolics_tearing.jl

Lines changed: 89 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,114 @@
1+
"""
2+
uneven_invmap(n::Int, list)
3+
4+
returns an uneven inv map with length `n`.
5+
"""
6+
function uneven_invmap(n::Int, list)
7+
rename = zero(Int, n)
8+
for (i, v) in enumerate(list)
9+
rename[v] = i
10+
end
11+
return rename
12+
end
13+
14+
# N.B. assumes `slist` and `dlist` are unique
15+
function substitution_graph(graph, slist, dlist, var_eq_matching)
16+
ns = length(slist)
17+
nd = length(dlist)
18+
ns == nd || error("internal error")
19+
newgraph = BipartiteGraph(ns, nd)
20+
erename = uneven_invmap(nsrc(graph), slist)
21+
vrename = uneven_invmap(ndst(graph), dlist)
22+
for e in 𝑠vertices(graph)
23+
ie = erename[e]
24+
ie == 0 && continue
25+
for v in 𝑠neighbors(graph, e)
26+
iv = vrename[v]
27+
iv == 0 && continue
28+
add_edge!(newgraph, ie, iv)
29+
end
30+
end
31+
32+
newmatching = zero(slist)
33+
for (v, e) in enumerate(var_eq_matching)
34+
iv = vrename[v]
35+
ie = erename[e]
36+
iv == 0 && continue
37+
ie == 0 && error("internal error")
38+
newmatching[iv] = ie
39+
end
40+
41+
return newgraph, newmatching
42+
end
43+
144
function tearing_sub(expr, dict, s)
245
expr = ModelingToolkit.fixpoint_sub(expr, dict)
346
s ? simplify(expr) : expr
447
end
548

49+
function tearing_substitution(sys::AbstractSystem; simplify=false)
50+
(has_substitutions(sys) && !isnothing(get_substitutions(sys))) || return sys
51+
subs = get_substitutions(sys)
52+
neweqs = map(equations(sys)) do eq
53+
if isdiffeq(eq)
54+
return eq.lhs ~ tearing_sub(eq.rhs, solved, simplify)
55+
else
56+
if !(eq.lhs isa Number && eq.lhs == 0)
57+
eq = 0 ~ eq.rhs - eq.lhs
58+
end
59+
rhs = tearing_sub(eq.rhs, solved, simplify)
60+
if rhs isa Symbolic
61+
return 0 ~ rhs
62+
else # a number
63+
error("tearing failled because the system is singular")
64+
end
65+
end
66+
eq
67+
end
68+
@set! sys.eqs = neweqs
69+
end
70+
71+
function solve_equation(eq, var, simplify)
72+
rhs = value(solve_for(eq, var; simplify=simplify, check=false))
73+
occursin(var, rhs) && error("solving $rhs for [$var] failed")
74+
var ~ rhs
75+
end
76+
77+
function normalize_equation(eq)
78+
if !isdiffeq(eq)
79+
if !(eq.lhs isa Number && eq.lhs == 0)
80+
eq = 0 ~ eq.rhs - eq.lhs
81+
end
82+
end
83+
eq
84+
end
85+
686
function tearing_reassemble(sys, var_eq_matching; simplify=false)
787
s = structure(sys)
888
@unpack fullvars, solvable_graph, graph = s
989

1090
eqs = equations(sys)
1191

1292
### extract partition information
13-
function solve_equation(ieq, iv)
14-
var = fullvars[iv]
15-
eq = eqs[ieq]
16-
rhs = value(solve_for(eq, var; simplify=simplify, check=false))
17-
18-
if var in vars(rhs)
19-
# Usually we should be done here, but if we don't simplify we can get in
20-
# trouble, so try our best to still solve for rhs
21-
if !simplify
22-
rhs = SymbolicUtils.polynormalize(rhs)
23-
end
24-
25-
# Since we know `eq` is linear wrt `var`, so the round off must be a
26-
# linear term. We can correct the round off error by a linear
27-
# correction.
28-
rhs -= expand_derivatives(Differential(var)(rhs))*var
29-
(var in vars(rhs)) && throw(EquationSolveErrors(eq, var, rhs))
30-
end
31-
var => rhs
32-
end
3393
is_solvable(eq, iv) = eq !== unassigned && BipartiteEdge(eq, iv) in solvable_graph
3494

3595
solved_equations = Int[]
3696
solved_variables = Int[]
3797

3898
# Solve solvable equations
39-
for (iv, ieq) in enumerate(var_eq_matching);
40-
is_solvable(ieq, iv) || continue
99+
for (iv, ieq) in enumerate(var_eq_matching)
100+
#is_solvable(ieq, iv) || continue
101+
is_solvable(ieq, iv) || error("unreachable reached")
41102
push!(solved_equations, ieq); push!(solved_variables, iv)
42103
end
43-
44-
solved = Dict(solve_equation(ieq, iv) for (ieq, iv) in zip(solved_equations, solved_variables))
45-
obseqs = [var ~ rhs for (var, rhs) in solved]
104+
subgraph, submatching = substitution_graph(graph, slist, dlist, var_eq_matching)
105+
toporder = topological_sort_by_dfs(DiCMOBiGraph{true}(subgraph, submatching))
106+
substitutions = [solve_equation(eqs[solved_equations[i]], fullvars[solved_variables[i]], simplify) for i in toporder]
46107

47108
# Rewrite remaining equations in terms of solved variables
48-
function substitute_equation(ieq)
49-
eq = eqs[ieq]
50-
if isdiffeq(eq)
51-
return eq.lhs ~ tearing_sub(eq.rhs, solved, simplify)
52-
else
53-
if !(eq.lhs isa Number && eq.lhs == 0)
54-
eq = 0 ~ eq.rhs - eq.lhs
55-
end
56-
rhs = tearing_sub(eq.rhs, solved, simplify)
57-
if rhs isa Symbolic
58-
return 0 ~ rhs
59-
else # a number
60-
if abs(rhs) > 100eps(float(rhs))
61-
@warn "The equation $eq is not consistent. It simplifed to 0 == $rhs."
62-
end
63-
return nothing
64-
end
65-
end
66-
end
67109

68-
neweqs = Any[substitute_equation(ieq) for ieq in 1:length(eqs) if !(ieq in solved_equations)]
69-
filter!(!isnothing, neweqs)
110+
solved_eq_set = BitSet(solved_equations)
111+
neweqs = Equation[normalize_equation(eqs[ieq]) for ieq in 1:length(eqs) if !(ieq in solved_eq_set)]
70112

71113
# Contract the vertices in the structure graph to make the structure match
72114
# the new reality of the system we've just created.
@@ -84,7 +126,7 @@ function tearing_reassemble(sys, var_eq_matching; simplify=false)
84126
@set! sys.structure = s
85127
@set! sys.eqs = neweqs
86128
@set! sys.states = [s.fullvars[idx] for idx in 1:length(s.fullvars) if !isdervar(s, idx)]
87-
@set! sys.observed = [observed(sys); obseqs]
129+
@set! sys.substitutions = substitutions
88130
return sys
89131
end
90132

src/systems/abstractsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ for prop in [
222222
:connector_type
223223
:connections
224224
:preface
225+
:substitutions
225226
]
226227
fname1 = Symbol(:get_, prop)
227228
fname2 = Symbol(:has_, prop)

src/systems/diffeqs/odesystem.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,20 @@ struct ODESystem <: AbstractODESystem
9696
The integrator will use root finding to guarantee that it steps at each zero crossing.
9797
"""
9898
continuous_events::Vector{SymbolicContinuousCallback}
99+
"""
100+
substitutions: substitutions generated by tearing.
101+
"""
102+
substitutions::Any
99103

100-
function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connector_type, connections, preface, events; checks::Bool = true)
104+
function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connector_type, connections, preface, events, substitutions=nothing; checks::Bool=true)
101105
if checks
102106
check_variables(dvs,iv)
103107
check_parameters(ps,iv)
104108
check_equations(deqs,iv)
105109
check_equations(equations(events),iv)
106110
all_dimensionless([dvs;ps;iv]) || check_units(deqs)
107111
end
108-
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connector_type, connections, preface, events)
112+
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connector_type, connections, preface, events, substitutions)
109113
end
110114
end
111115

@@ -153,7 +157,7 @@ function ODESystem(
153157
throw(ArgumentError("System names must be unique."))
154158
end
155159
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
156-
ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connector_type, nothing, preface, cont_callbacks, checks = checks)
160+
ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connector_type, nothing, preface, cont_callbacks, checks=checks)
157161
end
158162

159163
function ODESystem(eqs, iv=nothing; kwargs...)

src/systems/diffeqs/sdesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
106106
defaults=_merge(Dict(default_u0), Dict(default_p)),
107107
name=nothing,
108108
connector_type=nothing,
109-
checks = true,
109+
checks=true,
110110
)
111111
name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
112112
deqs = scalarize(deqs)

0 commit comments

Comments
 (0)