Skip to content

Commit 5cb05a5

Browse files
feat: add Symbolics.fast_substitute for affects
1 parent b1d4592 commit 5cb05a5

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

src/systems/callbacks.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ function SymbolicAffect(affect::SymbolicAffect; kwargs...)
2525
end
2626
SymbolicAffect(affect; kwargs...) = make_affect(affect; kwargs...)
2727

28+
function Symbolics.fast_substitute(aff::SymbolicAffect, rules)
29+
substituter = Base.Fix2(fast_substitute, rules)
30+
SymbolicAffect(map(substituter, aff.affect), map(substituter, aff.alg_eqs),
31+
map(substituter, aff.discrete_parameters))
32+
end
33+
2834
struct AffectSystem
2935
"""The internal implicit discrete system whose equations are solved to obtain values after the affect."""
3036
system::AbstractSystem
@@ -36,6 +42,19 @@ struct AffectSystem
3642
discretes::Vector
3743
end
3844

45+
function Symbolics.fast_substitute(aff::AffectSystem, rules)
46+
substituter = Base.Fix2(fast_substitute, rules)
47+
sys = aff.system
48+
@set! sys.eqs = map(substituter, get_eqs(sys))
49+
@set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys))
50+
@set! sys.defaults = Dict([k => substituter(v) for (k, v) in defaults(sys)])
51+
@set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)])
52+
@set! sys.unknowns = map(substituter, get_unknowns(sys))
53+
@set! sys.ps = map(substituter, get_ps(sys))
54+
AffectSystem(sys, map(substituter, aff.unknowns),
55+
map(substituter, aff.parameters), map(substituter, aff.discretes))
56+
end
57+
3958
function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], kwargs...)
4059
AffectSystem(spec.affect; alg_eqs = vcat(spec.alg_eqs, alg_eqs), iv,
4160
discrete_parameters = spec.discrete_parameters, kwargs...)

src/systems/imperative_affect.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ function ImperativeAffect(; f, kwargs...)
6767
ImperativeAffect(f; kwargs...)
6868
end
6969

70+
function Symbolics.fast_substitute(aff::ImperativeAffect, rules)
71+
substituter = Base.Fix2(fast_substitute, rules)
72+
ImperativeAffect(aff.f, map(substituter, aff.obs), aff.obs_syms,
73+
map(substituter, aff.modified), aff.mod_syms, aff.ctx, aff.skip_checks)
74+
end
75+
7076
function Base.show(io::IO, mfa::ImperativeAffect)
7177
obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ")
7278
mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ")

0 commit comments

Comments
 (0)