Skip to content

Commit d5b773a

Browse files
feat: implement IfLifting structural simplification pass
1 parent 63d2658 commit d5b773a

File tree

2 files changed

+392
-0
lines changed

2 files changed

+392
-0
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ include("discretedomain.jl")
181181
include("systems/systemstructure.jl")
182182
include("systems/clock_inference.jl")
183183
include("systems/systems.jl")
184+
include("systems/if_lifting.jl")
184185

185186
include("debugging.jl")
186187
include("systems/alias_elimination.jl")

src/systems/if_lifting.jl

Lines changed: 391 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,391 @@
1+
"""
2+
struct CondRewriter
3+
4+
Callable struct used to transform symbolic conditions into conditions involving discrete
5+
variables.
6+
"""
7+
struct CondRewriter
8+
"""
9+
The independent variable which the discrete variables depend on.
10+
"""
11+
iv::BasicSymbolic
12+
"""
13+
A mapping from a discrete variables to a `NamedTuple` containing the condition
14+
determining whether the discrete variable needs to be evaluated and the symbolic
15+
expression the discrete variable represents. The expression is a comparison operation
16+
such that the LHS of the comparison is used as a rootfinding function, and
17+
zero-crossings trigger re-evaluation of the condition (if `dependency` is `true`).
18+
"""
19+
conditions::Dict{Any, @NamedTuple{dependency, expression}}
20+
end
21+
22+
function CondRewriter(iv)
23+
return CondRewriter(iv, Dict())
24+
end
25+
26+
"""
27+
A function which transforms comparison operations of the form `var op var` into
28+
`var - var op 0`.
29+
"""
30+
const COMPARISON_TRANSFORM = unwrap SymbolicUtils.Rewriters.Chain([
31+
(@rule (~a) < (~b) => ~a - ~b < 0),
32+
(@rule (~a) > (~b) => ~a - ~b > 0),
33+
(@rule (~a) <= (~b) => ~a - ~b <= 0),
34+
(@rule (~a) >= (~b) => ~a - ~b >= 0),
35+
])
36+
37+
"""
38+
$(TYPEDSIGNATURES)
39+
40+
Given a symbolic condition `expr` and the condition `dep` it depends on, update the
41+
mapping in `cw` and generate a new discrete variable if necessary.
42+
"""
43+
function new_cond_sym(cw::CondRewriter, expr, dep)
44+
# check if the same expression exists in the mapping
45+
existing_var = findfirst(p -> isequal(p.expression, expr), cw.conditions)
46+
if existing_var !== nothing
47+
# cache hit
48+
(existing_dep, _) = cw.conditions[existing_var]
49+
# update the dependency condition
50+
cw.conditions[existing_var] = (dependency=(dep | existing_dep), expression=expr)
51+
return existing_var
52+
end
53+
# generate a new condition variable
54+
cvar = gensym("cond")
55+
st = symtype(expr)
56+
iv = cw.iv
57+
cv = first(@parameters $(cvar)(iv)::st = true) # TODO: real init
58+
cw.conditions[cv] = (dependency=dep, expression=expr)
59+
return cv
60+
end
61+
62+
"""
63+
A list of comparison operations.
64+
"""
65+
const COMPARISONS = Set([Base.:<, Base.:>, Base.:<=, Base.:>=])
66+
67+
"""
68+
Utility function for boolean implication.
69+
"""
70+
implies(a, b) = !a & b
71+
72+
"""
73+
$(TYPEDSIGNATURES)
74+
75+
Recursively rewrite conditions into discrete variables. `expr` is the condition to rewrite,
76+
`dep` is a boolean expression/value which determines when the `expr` is to be evaluated. For
77+
example, if `expr = expr1 | expr2` and `dep = dep1`, then `expr` should only be evaluated if
78+
`dep1` evaluates to `true`. Recursively, `expr1` should only be evaluated if `dep1` is `true`,
79+
and `expr2` should only be evaluated if `dep & !expr1`.
80+
81+
Returns a 3-tuple of the substituted expression, a condition describing when `expr` evaluates
82+
to `true`, and a condition describing when `expr` evaluates to `false`.
83+
"""
84+
function (cw::CondRewriter)(expr, dep)
85+
# single variable, trivial case
86+
if issym(expr) || iscall(expr) && issym(operation(expr))
87+
return (expr, expr, !expr)
88+
# literal boolean or integer
89+
elseif expr isa Bool
90+
return (expr, expr, !expr)
91+
elseif expr isa Int
92+
return (expr, true, true)
93+
# other singleton symbolic variables
94+
elseif !iscall(expr)
95+
@warn "Automatic conversion of if statments to events requires use of a limited conditional grammar; see the documentation. Skipping due to $expr"
96+
return (expr, true, true) # error case => conservative assumption is that both true and false have to be evaluated
97+
elseif operation(expr) == Base.:(|) # OR of two conditions
98+
a, b = arguments(expr)
99+
(rw_conda, truea, falsea) = cw(a, dep)
100+
# only evaluate second if first is false
101+
(rw_condb, trueb, falseb) = cw(b, dep & falsea)
102+
return (rw_conda | rw_condb, truea | trueb, falsea & falseb)
103+
104+
elseif operation(expr) == Base.:(&) # AND of two conditions
105+
a, b = arguments(expr)
106+
(rw_conda, truea, falsea) = cw(a, dep)
107+
# only evaluate second if first is true
108+
(rw_condb, trueb, falseb) = cw(b, dep & truea)
109+
return (rw_conda & rw_condb, truea & trueb, falsea | falseb)
110+
elseif operation(expr) == ifelse
111+
c, a, b = arguments(expr)
112+
(rw_cond, ctrue, cfalse) = cw(c, dep)
113+
# only evaluate if condition is true
114+
(rw_conda, truea, falsea) = cw(a, dep & ctrue)
115+
# only evaluate if condition is false
116+
(rw_condb, trueb, falseb) = cw(b, dep & cfalse)
117+
# expression is true if condition is true and THEN branch is true, or condition is false
118+
# and ELSE branch is true
119+
# similarly for expression being false
120+
return (ifelse(rw_cond, rw_conda, rw_condb), implies(ctrue, truea) | implies(cfalse, trueb), implies(ctrue, falsea) | implies(cfalse, falseb))
121+
elseif operation(expr) == Base.:(!) # NOT of expression
122+
(a,) = arguments(expr)
123+
(rw, ctrue, cfalse) = cw(a, dep)
124+
return (!rw, cfalse, ctrue)
125+
elseif operation(expr) in COMPARISONS # comparison operators
126+
# turn int `var - var op 0`
127+
expr = COMPARISON_TRANSFORM(expr)
128+
# a new discrete variable to represent `var - var op 0`
129+
cv = new_cond_sym(cw, expr, dep)
130+
return (cv, cv, !cv)
131+
elseif operation(expr) == (==)
132+
# we don't touch equality since it's a point discontinuity. It's basically always
133+
# false for continuous variables. In case it's an equality between discrete
134+
# quantities, we don't need to transform it.
135+
return (expr, expr, !expr)
136+
end
137+
error("Unsupported expression form in decision variable computation $expr")
138+
end
139+
140+
"""
141+
$(TYPEDSIGNATURES)
142+
143+
Acts as the identity function, and prevents transformation of conditional expressions inside it. Useful
144+
if specific `ifelse` or other functions with discontinuous derivatives shouldn't be transformed into
145+
callbacks.
146+
"""
147+
no_if_lift(s) = s
148+
@register_symbolic no_if_lift(s)
149+
150+
"""
151+
$(TYPEDEF)
152+
153+
A utility struct to search through an expression specifically for `ifelse` terms, and find
154+
all variables used in the condition of such terms. The variables are stored in a field of
155+
the struct.
156+
"""
157+
struct VarsUsedInCondition
158+
"""
159+
Stores variables used in conditions of `ifelse` statements in the expression.
160+
"""
161+
vars::Set{Any}
162+
end
163+
164+
VarsUsedInCondition() = VarsUsedInCondition(Set())
165+
166+
function (v::VarsUsedInCondition)(expr)
167+
expr = Symbolics.unwrap(expr)
168+
if symbolic_type(expr) == NotSymbolic()
169+
is_array_of_symbolics(expr) || return
170+
foreach(v, expr)
171+
return
172+
end
173+
iscall(expr) || return
174+
op = operation(expr)
175+
176+
# do not search inside no_if_lift to avoid discovering
177+
# redundant variables
178+
op == no_if_lift && return
179+
180+
args = arguments(expr)
181+
if op == ifelse
182+
cond, branch_a, branch_b = arguments(expr)
183+
vars!(v.vars, cond)
184+
v(branch_a)
185+
v(branch_b)
186+
end
187+
foreach(v, args)
188+
return
189+
end
190+
191+
"""
192+
$(TYPEDSIGNATURES)
193+
194+
Given an expression `expr` which is to be evaluated if `dep` evaluates to `true`, transform
195+
the conditions of all all `ifelse` statements in `expr` into functions of new discrete
196+
variables. `cw` is used to store the information relevant to these newly introduced variables.
197+
"""
198+
function rewrite_ifs(cw::CondRewriter, expr, dep)
199+
expr = unwrap(expr)
200+
if symbolic_type(expr) == NotSymbolic()
201+
is_array_of_symbolics(expr) || return expr
202+
return map(expr) do ex
203+
rewrite_ifs(cw, ex, dep)
204+
end
205+
end
206+
iscall(expr) || return expr
207+
op = operation(expr)
208+
# don't recurse into singleton variables or places where the user doesn't want if-lifting
209+
(issym(op) || op == no_if_lift) && return expr
210+
args = arguments(expr)
211+
212+
# transform `ifelse` that don't depend on a single symbolic variable.
213+
if op == ifelse && (!issym(args[1]) || iscall(args[1]) && !issym(operation(args[1])))
214+
cond, iftrue, iffalse = args
215+
(rw_cond, deptrue, depfalse) = cw(cond, dep)
216+
rw_iftrue = rewrite_ifs(cw, iftrue, deptrue)
217+
rw_iffalse = rewrite_ifs(cw, iffalse, depfalse)
218+
return ifelse(unwrap(rw_cond), rw_iftrue, rw_iffalse)
219+
end
220+
# recursively rewrite
221+
return maketerm(typeof(expr), op, map(x -> rewrite_ifs(cw, x, dep), args), metadata(expr))
222+
end
223+
224+
"""
225+
$(TYPEDSIGNATURES)
226+
227+
Return a modified `expr` where functions with known discontinuities or discontinuous
228+
derivatives are transformed into `ifelse` statements. Utilizes the discontinuity API
229+
in Symbolics. See [`Symbolics.rootfunction`](@ref),
230+
[`Symbolics.left_continuous_function`](@ref), [`Symbolics.right_continuous_function`](@ref).
231+
"""
232+
function discontinuities_to_ifelse(expr)
233+
if symbolic_type(expr) == NotSymbolic()
234+
is_array_of_symbolics(expr) || return expr
235+
return map(discontinuities_to_ifelse, expr)
236+
end
237+
iscall(expr) || return expr
238+
op = operation(expr)
239+
# don't transform inside `no_if_lift`
240+
(issym(op) || op === no_if_lift) && return expr
241+
args = arguments(expr)
242+
args = map(discontinuities_to_ifelse, args)
243+
# if the operation is a known discontinuity
244+
if hasmethod(Symbolics.rootfunction, Tuple{typeof(op)})
245+
rootfn = Symbolics.rootfunction(op)
246+
leftfn = Symbolics.left_continuous_function(op)
247+
rightfn = Symbolics.right_continuous_function(op)
248+
rootexpr = rootfn(args...) < 0
249+
leftexpr = leftfn(args...)
250+
rightexpr = rightfn(args...)
251+
return ifelse(rootexpr, leftexpr, rightexpr)
252+
end
253+
return maketerm(typeof(expr), op, args, Symbolics.metadata(expr))
254+
end
255+
256+
"""
257+
$(TYPEDSIGNATURES)
258+
259+
Generate the symbolic condition for discrete variable `sym`, which represents the condition
260+
of an `ifelse` statement created through [`IfLifting`](@ref). This condition is used to
261+
trigger a callback which updates the value of the condition appropriately.
262+
"""
263+
function generate_condition(cw::CondRewriter, sym)
264+
(dep, uexpr) = cw.conditions[sym]
265+
# `uexpr` is a comparison, the LHS is the zero-crossing function
266+
zero_crossing = arguments(uexpr)[1]
267+
# if we're meant to evaluate the condition, evaluate it. Otherwise, return `NaN`.
268+
# the solvers don't treat the transition from a number to NaN or back as a zero-crossing,
269+
# so it can be used to effectively disable the affect when the condition is not meant to
270+
# be evaluated.
271+
return ifelse(dep, arguments(uexpr)[1], NaN) ~ 0
272+
end
273+
274+
"""
275+
$(TYPEDSIGNATURES)
276+
277+
Generate the affect function for discrete variable `sym` involved in `ifelse` statements that
278+
are lifted to callbacks using [`IfLifting`](@ref). `syms` is a condition variable introduced
279+
by `cw`, and is thus a key in `cw.conditions`. `new_cond_vars` is the list of all such new
280+
condition variables, corresponding to the order of vertices in `new_cond_vars_graph`.
281+
`new_cond_vars_graph` is a directed graph where edges denote the condition variables involved
282+
in the dependency expression of the source vertex.
283+
"""
284+
function generate_affect(cw::CondRewriter, sym, new_cond_vars, new_cond_vars_graph)
285+
sym_idx = findfirst(isequal(sym), new_cond_vars)
286+
if sym_idx === nothing
287+
throw(ArgumentError("Expected variable $sym to be a condition variable in $new_cond_vars."))
288+
end
289+
# use reverse direction of edges because instead of finding the variables it depends
290+
# on, we want the variables that depend on it
291+
parents = bfs_parents(new_cond_vars_graph, sym_idx; dir = :in)
292+
cond_vars_to_update = [new_cond_vars[i] for i in eachindex(parents) if !iszero(parents[i])]
293+
update_syms = Symbol.(cond_vars_to_update)
294+
update_exprs = [last(cw.conditions[sym]) for sym in cond_vars_to_update]
295+
return ImperativeAffect(modified=NamedTuple{(update_syms...,)}(cond_vars_to_update), observed=NamedTuple{(update_syms...,)}(update_exprs), skip_checks=true) do x, o, c, i
296+
x .= o
297+
end
298+
end
299+
300+
"""
301+
If lifting converts (nested) if statements into a series of continous events + a logically equivalent if statement + parameters.
302+
303+
Lifting proceeds through the following process:
304+
* rewrite comparisons to be of the form eqn [op] 0; subtract the RHS from the LHS
305+
* replace comparisons with generated parameters; for each comparison eqn [op] 0, generate an event (dependent on op) that sets the parameter
306+
"""
307+
function IfLifting(sys::ODESystem)
308+
cw = CondRewriter(get_iv(sys))
309+
310+
eqs = copy(equations(sys))
311+
obs = copy(observed(sys))
312+
313+
# get variables used by `eqs`
314+
syms = vars(eqs)
315+
# get observed equations used by `eqs`
316+
obs_idxs = observed_equations_used_by(sys, eqs; involved_vars = syms)
317+
# and the variables used in those equations
318+
for i in obs_idxs
319+
vars!(syms, obs[i])
320+
end
321+
322+
# get all integral variables used in conditions
323+
# this is used when performing the transformation on observed equations
324+
# since they are transformed differently depending on whether they are
325+
# discrete variables involved in a condition or not
326+
condition_vars = Set()
327+
# searcher struct
328+
# we can use the same one since it avoids iterating over duplicates
329+
vars_in_condition! = VarsUsedInCondition()
330+
for i in eachindex(eqs)
331+
eq = eqs[i]
332+
vars_in_condition!(eq.rhs)
333+
# also transform the equation
334+
eqs[i] = eq.lhs ~ rewrite_ifs(cw, discontinuities_to_ifelse(eq.rhs), true)
335+
end
336+
# also search through relevant observed equations
337+
for i in obs_idxs
338+
vars_in_condition!(obs[i].rhs)
339+
end
340+
# add to `condition_vars` after filtering out differential, parameter, independent and
341+
# non-integral variables
342+
for v in vars_in_condition!.vars
343+
v = unwrap(v)
344+
stype = symtype(v)
345+
if isdifferential(v) || isparameter(v) || isequal(v, get_iv(sys))
346+
continue
347+
end
348+
stype <: Union{Integer, AbstractArray{Integer}} && push!(condition_vars, v)
349+
end
350+
# transform observed equations
351+
for i in obs_idxs
352+
obs[i] = if obs[i].lhs in condition_vars
353+
obs[i].lhs ~ first(cw(obs[i].rhs, true))
354+
else
355+
obs[i].lhs ~ rewrite_ifs(cw, discontinuities_to_ifelse(eq.rhs), true)
356+
end
357+
end
358+
359+
# get directed graph where nodes are the new condition variables and edges from each
360+
# node denote the condition variables used in it's dependency expression
361+
362+
# so we have an ordering for the vertices
363+
new_cond_vars = collect(keys(cw.conditions))
364+
# "observed" equations
365+
new_cond_dep_eqs = [v ~ cw.conditions[v] for v in new_cond_vars]
366+
# construct the graph as a `DiCMOBiGraph`
367+
new_cond_vars_graph = observed_dependency_graph(new_cond_dep_eqs)
368+
369+
new_callbacks = continuous_events(sys)
370+
new_defaults = defaults(sys)
371+
new_ps = parameters(sys)
372+
373+
for var in new_cond_vars
374+
condition = generate_condition(cw, var)
375+
affect = generate_affect(cw, var, new_cond_vars, new_cond_vars_graph)
376+
cb = SymbolicContinuousCallback([condition], affect; affect_neg=affect, initialize=affect, rootfind=SciMLBase.RightRootFind)
377+
378+
push!(new_callbacks, cb)
379+
new_defaults[var] = getdefault(var)
380+
push!(new_ps, var)
381+
end
382+
383+
@set! sys.defaults = new_defaults
384+
@set! sys.eqs = eqs
385+
# do not need to topsort because we didn't modify the order
386+
@set! sys.observed = obs
387+
@set! sys.continuous_events = new_callbacks
388+
@set! sys.ps = new_ps
389+
return sys
390+
end
391+

0 commit comments

Comments
 (0)