|
| 1 | +""" |
| 2 | + struct CondRewriter |
| 3 | +
|
| 4 | +Callable struct used to transform symbolic conditions into conditions involving discrete |
| 5 | +variables. |
| 6 | +""" |
| 7 | +struct CondRewriter |
| 8 | + """ |
| 9 | + The independent variable which the discrete variables depend on. |
| 10 | + """ |
| 11 | + iv::BasicSymbolic |
| 12 | + """ |
| 13 | + A mapping from a discrete variables to a `NamedTuple` containing the condition |
| 14 | + determining whether the discrete variable needs to be evaluated and the symbolic |
| 15 | + expression the discrete variable represents. The expression is a comparison operation |
| 16 | + such that the LHS of the comparison is used as a rootfinding function, and |
| 17 | + zero-crossings trigger re-evaluation of the condition (if `dependency` is `true`). |
| 18 | + """ |
| 19 | + conditions::Dict{Any, @NamedTuple{dependency, expression}} |
| 20 | +end |
| 21 | + |
| 22 | +function CondRewriter(iv) |
| 23 | + return CondRewriter(iv, Dict()) |
| 24 | +end |
| 25 | + |
| 26 | +""" |
| 27 | +A function which transforms comparison operations of the form `var op var` into |
| 28 | +`var - var op 0`. |
| 29 | +""" |
| 30 | +const COMPARISON_TRANSFORM = unwrap ∘ SymbolicUtils.Rewriters.Chain([ |
| 31 | + (@rule (~a) < (~b) => ~a - ~b < 0), |
| 32 | + (@rule (~a) > (~b) => ~a - ~b > 0), |
| 33 | + (@rule (~a) <= (~b) => ~a - ~b <= 0), |
| 34 | + (@rule (~a) >= (~b) => ~a - ~b >= 0), |
| 35 | +]) |
| 36 | + |
| 37 | +""" |
| 38 | + $(TYPEDSIGNATURES) |
| 39 | +
|
| 40 | +Given a symbolic condition `expr` and the condition `dep` it depends on, update the |
| 41 | +mapping in `cw` and generate a new discrete variable if necessary. |
| 42 | +""" |
| 43 | +function new_cond_sym(cw::CondRewriter, expr, dep) |
| 44 | + # check if the same expression exists in the mapping |
| 45 | + existing_var = findfirst(p -> isequal(p.expression, expr), cw.conditions) |
| 46 | + if existing_var !== nothing |
| 47 | + # cache hit |
| 48 | + (existing_dep, _) = cw.conditions[existing_var] |
| 49 | + # update the dependency condition |
| 50 | + cw.conditions[existing_var] = (dependency=(dep | existing_dep), expression=expr) |
| 51 | + return existing_var |
| 52 | + end |
| 53 | + # generate a new condition variable |
| 54 | + cvar = gensym("cond") |
| 55 | + st = symtype(expr) |
| 56 | + iv = cw.iv |
| 57 | + cv = first(@parameters $(cvar)(iv)::st = true) # TODO: real init |
| 58 | + cw.conditions[cv] = (dependency=dep, expression=expr) |
| 59 | + return cv |
| 60 | +end |
| 61 | + |
| 62 | +""" |
| 63 | +A list of comparison operations. |
| 64 | +""" |
| 65 | +const COMPARISONS = Set([Base.:<, Base.:>, Base.:<=, Base.:>=]) |
| 66 | + |
| 67 | +""" |
| 68 | +Utility function for boolean implication. |
| 69 | +""" |
| 70 | +implies(a, b) = !a & b |
| 71 | + |
| 72 | +""" |
| 73 | + $(TYPEDSIGNATURES) |
| 74 | +
|
| 75 | +Recursively rewrite conditions into discrete variables. `expr` is the condition to rewrite, |
| 76 | +`dep` is a boolean expression/value which determines when the `expr` is to be evaluated. For |
| 77 | +example, if `expr = expr1 | expr2` and `dep = dep1`, then `expr` should only be evaluated if |
| 78 | +`dep1` evaluates to `true`. Recursively, `expr1` should only be evaluated if `dep1` is `true`, |
| 79 | +and `expr2` should only be evaluated if `dep & !expr1`. |
| 80 | +
|
| 81 | +Returns a 3-tuple of the substituted expression, a condition describing when `expr` evaluates |
| 82 | +to `true`, and a condition describing when `expr` evaluates to `false`. |
| 83 | +""" |
| 84 | +function (cw::CondRewriter)(expr, dep) |
| 85 | + # single variable, trivial case |
| 86 | + if issym(expr) || iscall(expr) && issym(operation(expr)) |
| 87 | + return (expr, expr, !expr) |
| 88 | + # literal boolean or integer |
| 89 | + elseif expr isa Bool |
| 90 | + return (expr, expr, !expr) |
| 91 | + elseif expr isa Int |
| 92 | + return (expr, true, true) |
| 93 | + # other singleton symbolic variables |
| 94 | + elseif !iscall(expr) |
| 95 | + @warn "Automatic conversion of if statments to events requires use of a limited conditional grammar; see the documentation. Skipping due to $expr" |
| 96 | + return (expr, true, true) # error case => conservative assumption is that both true and false have to be evaluated |
| 97 | + elseif operation(expr) == Base.:(|) # OR of two conditions |
| 98 | + a, b = arguments(expr) |
| 99 | + (rw_conda, truea, falsea) = cw(a, dep) |
| 100 | + # only evaluate second if first is false |
| 101 | + (rw_condb, trueb, falseb) = cw(b, dep & falsea) |
| 102 | + return (rw_conda | rw_condb, truea | trueb, falsea & falseb) |
| 103 | + |
| 104 | + elseif operation(expr) == Base.:(&) # AND of two conditions |
| 105 | + a, b = arguments(expr) |
| 106 | + (rw_conda, truea, falsea) = cw(a, dep) |
| 107 | + # only evaluate second if first is true |
| 108 | + (rw_condb, trueb, falseb) = cw(b, dep & truea) |
| 109 | + return (rw_conda & rw_condb, truea & trueb, falsea | falseb) |
| 110 | + elseif operation(expr) == ifelse |
| 111 | + c, a, b = arguments(expr) |
| 112 | + (rw_cond, ctrue, cfalse) = cw(c, dep) |
| 113 | + # only evaluate if condition is true |
| 114 | + (rw_conda, truea, falsea) = cw(a, dep & ctrue) |
| 115 | + # only evaluate if condition is false |
| 116 | + (rw_condb, trueb, falseb) = cw(b, dep & cfalse) |
| 117 | + # expression is true if condition is true and THEN branch is true, or condition is false |
| 118 | + # and ELSE branch is true |
| 119 | + # similarly for expression being false |
| 120 | + return (ifelse(rw_cond, rw_conda, rw_condb), implies(ctrue, truea) | implies(cfalse, trueb), implies(ctrue, falsea) | implies(cfalse, falseb)) |
| 121 | + elseif operation(expr) == Base.:(!) # NOT of expression |
| 122 | + (a,) = arguments(expr) |
| 123 | + (rw, ctrue, cfalse) = cw(a, dep) |
| 124 | + return (!rw, cfalse, ctrue) |
| 125 | + elseif operation(expr) in COMPARISONS # comparison operators |
| 126 | + # turn int `var - var op 0` |
| 127 | + expr = COMPARISON_TRANSFORM(expr) |
| 128 | + # a new discrete variable to represent `var - var op 0` |
| 129 | + cv = new_cond_sym(cw, expr, dep) |
| 130 | + return (cv, cv, !cv) |
| 131 | + elseif operation(expr) == (==) |
| 132 | + # we don't touch equality since it's a point discontinuity. It's basically always |
| 133 | + # false for continuous variables. In case it's an equality between discrete |
| 134 | + # quantities, we don't need to transform it. |
| 135 | + return (expr, expr, !expr) |
| 136 | + end |
| 137 | + error("Unsupported expression form in decision variable computation $expr") |
| 138 | +end |
| 139 | + |
| 140 | +""" |
| 141 | + $(TYPEDSIGNATURES) |
| 142 | +
|
| 143 | +Acts as the identity function, and prevents transformation of conditional expressions inside it. Useful |
| 144 | +if specific `ifelse` or other functions with discontinuous derivatives shouldn't be transformed into |
| 145 | +callbacks. |
| 146 | +""" |
| 147 | +no_if_lift(s) = s |
| 148 | +@register_symbolic no_if_lift(s) |
| 149 | + |
| 150 | +""" |
| 151 | + $(TYPEDEF) |
| 152 | +
|
| 153 | +A utility struct to search through an expression specifically for `ifelse` terms, and find |
| 154 | +all variables used in the condition of such terms. The variables are stored in a field of |
| 155 | +the struct. |
| 156 | +""" |
| 157 | +struct VarsUsedInCondition |
| 158 | + """ |
| 159 | + Stores variables used in conditions of `ifelse` statements in the expression. |
| 160 | + """ |
| 161 | + vars::Set{Any} |
| 162 | +end |
| 163 | + |
| 164 | +VarsUsedInCondition() = VarsUsedInCondition(Set()) |
| 165 | + |
| 166 | +function (v::VarsUsedInCondition)(expr) |
| 167 | + expr = Symbolics.unwrap(expr) |
| 168 | + if symbolic_type(expr) == NotSymbolic() |
| 169 | + is_array_of_symbolics(expr) || return |
| 170 | + foreach(v, expr) |
| 171 | + return |
| 172 | + end |
| 173 | + iscall(expr) || return |
| 174 | + op = operation(expr) |
| 175 | + |
| 176 | + # do not search inside no_if_lift to avoid discovering |
| 177 | + # redundant variables |
| 178 | + op == no_if_lift && return |
| 179 | + |
| 180 | + args = arguments(expr) |
| 181 | + if op == ifelse |
| 182 | + cond, branch_a, branch_b = arguments(expr) |
| 183 | + vars!(v.vars, cond) |
| 184 | + v(branch_a) |
| 185 | + v(branch_b) |
| 186 | + end |
| 187 | + foreach(v, args) |
| 188 | + return |
| 189 | +end |
| 190 | + |
| 191 | +""" |
| 192 | + $(TYPEDSIGNATURES) |
| 193 | +
|
| 194 | +Given an expression `expr` which is to be evaluated if `dep` evaluates to `true`, transform |
| 195 | +the conditions of all all `ifelse` statements in `expr` into functions of new discrete |
| 196 | +variables. `cw` is used to store the information relevant to these newly introduced variables. |
| 197 | +""" |
| 198 | +function rewrite_ifs(cw::CondRewriter, expr, dep) |
| 199 | + expr = unwrap(expr) |
| 200 | + if symbolic_type(expr) == NotSymbolic() |
| 201 | + is_array_of_symbolics(expr) || return expr |
| 202 | + return map(expr) do ex |
| 203 | + rewrite_ifs(cw, ex, dep) |
| 204 | + end |
| 205 | + end |
| 206 | + iscall(expr) || return expr |
| 207 | + op = operation(expr) |
| 208 | + # don't recurse into singleton variables or places where the user doesn't want if-lifting |
| 209 | + (issym(op) || op == no_if_lift) && return expr |
| 210 | + args = arguments(expr) |
| 211 | + |
| 212 | + # transform `ifelse` that don't depend on a single symbolic variable. |
| 213 | + if op == ifelse && (!issym(args[1]) || iscall(args[1]) && !issym(operation(args[1]))) |
| 214 | + cond, iftrue, iffalse = args |
| 215 | + (rw_cond, deptrue, depfalse) = cw(cond, dep) |
| 216 | + rw_iftrue = rewrite_ifs(cw, iftrue, deptrue) |
| 217 | + rw_iffalse = rewrite_ifs(cw, iffalse, depfalse) |
| 218 | + return ifelse(unwrap(rw_cond), rw_iftrue, rw_iffalse) |
| 219 | + end |
| 220 | + # recursively rewrite |
| 221 | + return maketerm(typeof(expr), op, map(x -> rewrite_ifs(cw, x, dep), args), metadata(expr)) |
| 222 | +end |
| 223 | + |
| 224 | +""" |
| 225 | + $(TYPEDSIGNATURES) |
| 226 | +
|
| 227 | +Return a modified `expr` where functions with known discontinuities or discontinuous |
| 228 | +derivatives are transformed into `ifelse` statements. Utilizes the discontinuity API |
| 229 | +in Symbolics. See [`Symbolics.rootfunction`](@ref), |
| 230 | +[`Symbolics.left_continuous_function`](@ref), [`Symbolics.right_continuous_function`](@ref). |
| 231 | +""" |
| 232 | +function discontinuities_to_ifelse(expr) |
| 233 | + if symbolic_type(expr) == NotSymbolic() |
| 234 | + is_array_of_symbolics(expr) || return expr |
| 235 | + return map(discontinuities_to_ifelse, expr) |
| 236 | + end |
| 237 | + iscall(expr) || return expr |
| 238 | + op = operation(expr) |
| 239 | + # don't transform inside `no_if_lift` |
| 240 | + (issym(op) || op === no_if_lift) && return expr |
| 241 | + args = arguments(expr) |
| 242 | + args = map(discontinuities_to_ifelse, args) |
| 243 | + # if the operation is a known discontinuity |
| 244 | + if hasmethod(Symbolics.rootfunction, Tuple{typeof(op)}) |
| 245 | + rootfn = Symbolics.rootfunction(op) |
| 246 | + leftfn = Symbolics.left_continuous_function(op) |
| 247 | + rightfn = Symbolics.right_continuous_function(op) |
| 248 | + rootexpr = rootfn(args...) < 0 |
| 249 | + leftexpr = leftfn(args...) |
| 250 | + rightexpr = rightfn(args...) |
| 251 | + return ifelse(rootexpr, leftexpr, rightexpr) |
| 252 | + end |
| 253 | + return maketerm(typeof(expr), op, args, Symbolics.metadata(expr)) |
| 254 | +end |
| 255 | + |
| 256 | +""" |
| 257 | + $(TYPEDSIGNATURES) |
| 258 | +
|
| 259 | +Generate the symbolic condition for discrete variable `sym`, which represents the condition |
| 260 | +of an `ifelse` statement created through [`IfLifting`](@ref). This condition is used to |
| 261 | +trigger a callback which updates the value of the condition appropriately. |
| 262 | +""" |
| 263 | +function generate_condition(cw::CondRewriter, sym) |
| 264 | + (dep, uexpr) = cw.conditions[sym] |
| 265 | + # `uexpr` is a comparison, the LHS is the zero-crossing function |
| 266 | + zero_crossing = arguments(uexpr)[1] |
| 267 | + # if we're meant to evaluate the condition, evaluate it. Otherwise, return `NaN`. |
| 268 | + # the solvers don't treat the transition from a number to NaN or back as a zero-crossing, |
| 269 | + # so it can be used to effectively disable the affect when the condition is not meant to |
| 270 | + # be evaluated. |
| 271 | + return ifelse(dep, arguments(uexpr)[1], NaN) ~ 0 |
| 272 | +end |
| 273 | + |
| 274 | +""" |
| 275 | + $(TYPEDSIGNATURES) |
| 276 | +
|
| 277 | +Generate the affect function for discrete variable `sym` involved in `ifelse` statements that |
| 278 | +are lifted to callbacks using [`IfLifting`](@ref). `syms` is a condition variable introduced |
| 279 | +by `cw`, and is thus a key in `cw.conditions`. `new_cond_vars` is the list of all such new |
| 280 | +condition variables, corresponding to the order of vertices in `new_cond_vars_graph`. |
| 281 | +`new_cond_vars_graph` is a directed graph where edges denote the condition variables involved |
| 282 | +in the dependency expression of the source vertex. |
| 283 | +""" |
| 284 | +function generate_affect(cw::CondRewriter, sym, new_cond_vars, new_cond_vars_graph) |
| 285 | + sym_idx = findfirst(isequal(sym), new_cond_vars) |
| 286 | + if sym_idx === nothing |
| 287 | + throw(ArgumentError("Expected variable $sym to be a condition variable in $new_cond_vars.")) |
| 288 | + end |
| 289 | + # use reverse direction of edges because instead of finding the variables it depends |
| 290 | + # on, we want the variables that depend on it |
| 291 | + parents = bfs_parents(new_cond_vars_graph, sym_idx; dir = :in) |
| 292 | + cond_vars_to_update = [new_cond_vars[i] for i in eachindex(parents) if !iszero(parents[i])] |
| 293 | + update_syms = Symbol.(cond_vars_to_update) |
| 294 | + update_exprs = [last(cw.conditions[sym]) for sym in cond_vars_to_update] |
| 295 | + return ImperativeAffect(modified=NamedTuple{(update_syms...,)}(cond_vars_to_update), observed=NamedTuple{(update_syms...,)}(update_exprs), skip_checks=true) do x, o, c, i |
| 296 | + x .= o |
| 297 | + end |
| 298 | +end |
| 299 | + |
| 300 | +""" |
| 301 | +If lifting converts (nested) if statements into a series of continous events + a logically equivalent if statement + parameters. |
| 302 | +
|
| 303 | +Lifting proceeds through the following process: |
| 304 | +* rewrite comparisons to be of the form eqn [op] 0; subtract the RHS from the LHS |
| 305 | +* replace comparisons with generated parameters; for each comparison eqn [op] 0, generate an event (dependent on op) that sets the parameter |
| 306 | +""" |
| 307 | +function IfLifting(sys::ODESystem) |
| 308 | + cw = CondRewriter(get_iv(sys)) |
| 309 | + |
| 310 | + eqs = copy(equations(sys)) |
| 311 | + obs = copy(observed(sys)) |
| 312 | + |
| 313 | + # get variables used by `eqs` |
| 314 | + syms = vars(eqs) |
| 315 | + # get observed equations used by `eqs` |
| 316 | + obs_idxs = observed_equations_used_by(sys, eqs; involved_vars = syms) |
| 317 | + # and the variables used in those equations |
| 318 | + for i in obs_idxs |
| 319 | + vars!(syms, obs[i]) |
| 320 | + end |
| 321 | + |
| 322 | + # get all integral variables used in conditions |
| 323 | + # this is used when performing the transformation on observed equations |
| 324 | + # since they are transformed differently depending on whether they are |
| 325 | + # discrete variables involved in a condition or not |
| 326 | + condition_vars = Set() |
| 327 | + # searcher struct |
| 328 | + # we can use the same one since it avoids iterating over duplicates |
| 329 | + vars_in_condition! = VarsUsedInCondition() |
| 330 | + for i in eachindex(eqs) |
| 331 | + eq = eqs[i] |
| 332 | + vars_in_condition!(eq.rhs) |
| 333 | + # also transform the equation |
| 334 | + eqs[i] = eq.lhs ~ rewrite_ifs(cw, discontinuities_to_ifelse(eq.rhs), true) |
| 335 | + end |
| 336 | + # also search through relevant observed equations |
| 337 | + for i in obs_idxs |
| 338 | + vars_in_condition!(obs[i].rhs) |
| 339 | + end |
| 340 | + # add to `condition_vars` after filtering out differential, parameter, independent and |
| 341 | + # non-integral variables |
| 342 | + for v in vars_in_condition!.vars |
| 343 | + v = unwrap(v) |
| 344 | + stype = symtype(v) |
| 345 | + if isdifferential(v) || isparameter(v) || isequal(v, get_iv(sys)) |
| 346 | + continue |
| 347 | + end |
| 348 | + stype <: Union{Integer, AbstractArray{Integer}} && push!(condition_vars, v) |
| 349 | + end |
| 350 | + # transform observed equations |
| 351 | + for i in obs_idxs |
| 352 | + obs[i] = if obs[i].lhs in condition_vars |
| 353 | + obs[i].lhs ~ first(cw(obs[i].rhs, true)) |
| 354 | + else |
| 355 | + obs[i].lhs ~ rewrite_ifs(cw, discontinuities_to_ifelse(eq.rhs), true) |
| 356 | + end |
| 357 | + end |
| 358 | + |
| 359 | + # get directed graph where nodes are the new condition variables and edges from each |
| 360 | + # node denote the condition variables used in it's dependency expression |
| 361 | + |
| 362 | + # so we have an ordering for the vertices |
| 363 | + new_cond_vars = collect(keys(cw.conditions)) |
| 364 | + # "observed" equations |
| 365 | + new_cond_dep_eqs = [v ~ cw.conditions[v] for v in new_cond_vars] |
| 366 | + # construct the graph as a `DiCMOBiGraph` |
| 367 | + new_cond_vars_graph = observed_dependency_graph(new_cond_dep_eqs) |
| 368 | + |
| 369 | + new_callbacks = continuous_events(sys) |
| 370 | + new_defaults = defaults(sys) |
| 371 | + new_ps = parameters(sys) |
| 372 | + |
| 373 | + for var in new_cond_vars |
| 374 | + condition = generate_condition(cw, var) |
| 375 | + affect = generate_affect(cw, var, new_cond_vars, new_cond_vars_graph) |
| 376 | + cb = SymbolicContinuousCallback([condition], affect; affect_neg=affect, initialize=affect, rootfind=SciMLBase.RightRootFind) |
| 377 | + |
| 378 | + push!(new_callbacks, cb) |
| 379 | + new_defaults[var] = getdefault(var) |
| 380 | + push!(new_ps, var) |
| 381 | + end |
| 382 | + |
| 383 | + @set! sys.defaults = new_defaults |
| 384 | + @set! sys.eqs = eqs |
| 385 | + # do not need to topsort because we didn't modify the order |
| 386 | + @set! sys.observed = obs |
| 387 | + @set! sys.continuous_events = new_callbacks |
| 388 | + @set! sys.ps = new_ps |
| 389 | + return sys |
| 390 | +end |
| 391 | + |
0 commit comments