Skip to content

Commit 8160127

Browse files
committed
refactor: refactor affect codegen
1 parent 5b97ebb commit 8160127

File tree

3 files changed

+147
-29
lines changed

3 files changed

+147
-29
lines changed

src/ModelingToolkit.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@ include("systems/model_parsing.jl")
154154
include("systems/connectors.jl")
155155
include("systems/analysis_points.jl")
156156
include("systems/imperative_affect.jl")
157-
include("systems/callbacks.jl")
158157
include("systems/codegen_utils.jl")
159158
include("systems/problem_utils.jl")
160159
include("linearization.jl")
@@ -164,19 +163,20 @@ include("systems/optimization/optimizationsystem.jl")
164163
include("systems/optimization/modelingtoolkitize.jl")
165164

166165
include("systems/nonlinear/nonlinearsystem.jl")
167-
include("systems/nonlinear/homotopy_continuation.jl")
166+
include("systems/discrete_system/discrete_system.jl")
167+
include("systems/discrete_system/implicit_discrete_system.jl")
168+
include("systems/callbacks.jl")
169+
168170
include("systems/diffeqs/odesystem.jl")
169171
include("systems/diffeqs/sdesystem.jl")
170172
include("systems/diffeqs/abstractodesystem.jl")
173+
include("systems/nonlinear/homotopy_continuation.jl")
171174
include("systems/nonlinear/modelingtoolkitize.jl")
172175
include("systems/nonlinear/initializesystem.jl")
173176
include("systems/diffeqs/first_order_transform.jl")
174177
include("systems/diffeqs/modelingtoolkitize.jl")
175178
include("systems/diffeqs/basic_transformations.jl")
176179

177-
include("systems/discrete_system/discrete_system.jl")
178-
include("systems/discrete_system/implicit_discrete_system.jl")
179-
180180
include("systems/jumps/jumpsystem.jl")
181181

182182
include("systems/pde/pdesystem.jl")

src/systems/callbacks.jl

Lines changed: 141 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,35 +11,139 @@ function get_discrete_events(sys::AbstractSystem)
1111
getfield(sys, :discrete_events)
1212
end
1313

14-
struct Callback
15-
eqs::Vector{Equation}
16-
initialize::Union{ImplicitDiscreteSystem, FunctionalAffect, ImperativeAffect}
17-
finalize::ImplicitDiscreteSystem
18-
affect::ImplicitDiscreteSystem
19-
affect_neg::ImplicitDiscreteSystem
20-
rootfind::Union{Nothing, SciMLBase.RootfindOpt}
21-
end
14+
abstract type Callback end
15+
16+
const Affect = Union{ImplicitDiscreteSystem, FunctionalAffect, ImperativeAffect}
2217

2318
# Callbacks:
2419
# mapping (cond) => ImplicitDiscreteSystem
2520
function generate_continuous_callbacks(events, sys)
2621
algeeqs = alg_equations(sys)
27-
callbacks = Callback[]
28-
for (cond, aff) in events
29-
@mtkbuild affect = ImplicitDiscreteSystem([aff, algeeqs], t)
30-
push!(callbacks, Callback(cond, NULL_AFFECT, NULL_AFFECT, affect, affect, SciMLBase.LeftRootFind))
22+
callbacks = MTKContinuousCallback[]
23+
for (cond, affs) in events
24+
@mtkbuild affect = ImplicitDiscreteSystem([affs, algeeqs], t)
25+
push!(callbacks, MTKContinuousCallback(cond, NULL_AFFECT, NULL_AFFECT, affect, affect, SciMLBase.LeftRootFind))
3126
end
3227
callbacks
3328
end
3429

35-
function generate_discrete_callback_system(events, sys)
30+
function generate_discrete_callbacks(events, sys)
31+
algeeqs = alg_equations(sys)
32+
callbacks = MTKDiscreteCallback[]
33+
for (cond, affs) in events
34+
@mtkbuild affect = ImplicitDiscreteSystem([affs, algeeqs], t)
35+
push!(callbacks, MTKDiscreteCallback(cond, NULL_AFFECT, NULL_AFFECT, affect))
36+
end
37+
callbacks
3638
end
3739

38-
function generate_callback_function()
39-
40+
"""
41+
Create a DifferentialEquations callback. A set of continuous callbacks becomes a VectorContinuousCallback.
42+
"""
43+
function create_callback(cbs::Vector{MTKContinuousCallback}, sys; is_discrete = false)
44+
eqs = flatten_equations(cbs)
45+
_, f_iip = generate_custom_function(
46+
sys, [eq.lhs - eq.rhs for eq in eqs], unknowns(sys), parameters(sys);
47+
expression = Val{false})
48+
trigger = (out, u, t, integ) -> f_iip(out, u, parameter_values(integ), t)
49+
50+
affects = []
51+
affect_negs = []
52+
inits = []
53+
finals = []
54+
for cb in cbs
55+
affect = compile_affect(cb.affect)
56+
push!(affects, affect)
57+
isnothing(cb.affect_neg) ? push!(affect_negs, affect) : push!(affect_negs, compile_affect(cb.affect_neg))
58+
push!(inits, compile_affect(cb.initialize, default = SciMLBase.INITALIZE_DEFAULT))
59+
push!(finals, compile_affect(cb.finalize, default = SciMLBase.FINALIZE_DEFAULT))
60+
end
61+
62+
# since there may be different number of conditions and affects,
63+
# we build a map that translates the condition eq. number to the affect number
64+
num_eqs = length.(eqs)
65+
eq2affect = reduce(vcat,
66+
[fill(i, num_eqs[i]) for i in eachindex(affects)])
67+
@assert length(eq2affect) == length(eqs)
68+
@assert maximum(eq2affect) == length(affect_functions)
69+
70+
affect = function (integ, idx)
71+
affects[eq2affect[idx]](integ)
72+
end
73+
affect_neg = function (integ, idx)
74+
f = affect_negs[eq2affect[idx]]
75+
isnothing(f) && return
76+
f(integ)
77+
end
78+
initialize = compile_optional_setup(inits, SciMLBase.INITIALIZE_DEFAULT)
79+
finalize = compile_optional_setup(finals, SciMLBase.FINALIZE_DEFAULT)
80+
81+
return VectorContinuousCallback(trigger, affect; affect_neg, initialize, finalize, rootfind = callback.rootfind, initializealg = SciMLBase.NoInit)
82+
end
83+
84+
function create_callback(cb, sys; is_discrete = false)
85+
is_timed = is_timed_condition(cb)
86+
87+
trigger = if is_discrete
88+
is_timed ? condition(cb) :
89+
compile_condition(callback, sys, unknowns(sys), parameters(sys))
90+
else
91+
_, f_iip = generate_custom_function(
92+
sys, [eq.rhs - eq.lhs for eq in equations(cb)], unknowns(sys), parameters(sys);
93+
expression = Val{false})
94+
(out, u, t, integ) -> f_iip(out, u, parameter_values(integ), t)
95+
end
96+
97+
affect = compile_affect(cb.affect)
98+
affect_neg = isnothing(cb.affect_neg) ? affect_fn : compile_affect(cb.affect_neg)
99+
initialize = compile_affect(cb.initialize, default = SciMLBase.INITIALIZE_DEFAULT)
100+
finalize = compile_affect(cb.finalize, default = SciMLBase.FINALIZE_DEFAULT)
101+
102+
if is_discrete
103+
if is_timed && condition(cb) isa AbstractVector
104+
return PresetTimeCallback(trigger, affect; affect_neg, initialize, finalize, initializealg = SciMLBase.NoInit)
105+
elseif is_timed
106+
return PeriodicCallback(affect, trigger; initialize, finalize)
107+
else
108+
return DiscreteCallback(trigger, affect; affect_neg, initialize, finalize, initializealg = SciMLBase.NoInit)
109+
end
110+
else
111+
return ContinuousCallback(trigger, affect; affect_neg, initialize, finalize, rootfind = callback.rootfind, initializealg = SciMLBase.NoInit)
112+
end
113+
end
114+
115+
function compile_affect(aff; default = nothing)
116+
if aff isa ImplicitDiscreteSystem
117+
function affect!(integrator)
118+
u0map = [u => integrator[u] for u in unknowns(aff)]
119+
pmap = [p => integrator[p] for p in parameters(aff)]
120+
prob = ImplicitDiscreteProblem(aff, u0map, (0, 1), pmap)
121+
sol = solve(prob)
122+
for u in unknowns(aff)
123+
integrator[u] = sol[u][end]
124+
end
125+
for p in parameters(aff)
126+
integrator[p] = sol[p][end]
127+
end
128+
end
129+
elseif aff isa FunctionalAffect || aff isa ImperativeAffect
130+
compile_user_affect(aff, callback, sys, unknowns(sys), parameters(sys))
131+
else
132+
default
133+
end
134+
end
135+
136+
function compile_setup_funcs(funs, default)
137+
all(isnothing, funs) && return default
138+
return let funs = funs
139+
function (cb, u, t, integ)
140+
for func in funs
141+
isnothing(func) ? continue : func(integ)
142+
end
143+
end
144+
end
40145
end
41146

42-
############# Old implementation ###
43147
struct FunctionalAffect
44148
f::Any
45149
sts::Vector
@@ -50,6 +154,22 @@ struct FunctionalAffect
50154
ctx::Any
51155
end
52156

157+
struct MTKContinuousCallback <: Callback
158+
eqs::Vector{Equation}
159+
initialize::Union{Affect, Nothing}
160+
finalize::Union{Affect, Nothing}
161+
affect::Affect
162+
affect_neg::Union{Affect, Nothing}
163+
rootfind::Union{Nothing, SciMLBase.RootfindOpt}
164+
end
165+
166+
struct MTKDiscreteCallback <: Callback
167+
conds::Vector{Equation}
168+
initialize::Union{Affect, Nothing}
169+
finalize::Union{Affect, Nothing}
170+
affect::Affect
171+
end
172+
53173
function FunctionalAffect(f, sts, pars, discretes, ctx = nothing)
54174
# sts & pars contain either pairs: resistor.R => R, or Syms: R
55175
vs = [x isa Pair ? x.first : x for x in sts]
@@ -67,7 +187,7 @@ function FunctionalAffect(; f, sts, pars, discretes, ctx = nothing)
67187
FunctionalAffect(f, sts, pars, discretes, ctx)
68188
end
69189

70-
func(f::FunctionalAffect) = f.f
190+
func(a::FunctionalAffect) = a.f
71191
context(a::FunctionalAffect) = a.ctx
72192
parameters(a::FunctionalAffect) = a.pars
73193
parameters_syms(a::FunctionalAffect) = a.pars_syms
@@ -699,6 +819,7 @@ function compile_affect(eqs::Vector{Equation}, cb, sys, dvs, ps; outputidxs = no
699819
outputidxs = update_inds,
700820
create_bindings = false,
701821
kwargs...)
822+
@show rf_oop
702823
# applied user-provided function to the generated expression
703824
if postprocess_affect_expr! !== nothing
704825
postprocess_affect_expr!(rf_ip, integ)
@@ -869,13 +990,7 @@ function compile_affect_fn(cb, sys::AbstractTimeDependentSystem, dvs, ps, kwargs
869990
eq_aff = affects(cb)
870991
eq_neg_aff = affect_negs(cb)
871992
affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
872-
function compile_optional_affect(aff, default = nothing)
873-
if isnothing(aff) || aff == default
874-
return nothing
875-
else
876-
return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
877-
end
878-
end
993+
879994
if eq_neg_aff === eq_aff
880995
affect_neg = affect
881996
else
@@ -1017,13 +1132,15 @@ end
10171132
function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...)
10181133
compile_user_affect(affect, cb, sys, dvs, ps; kwargs...)
10191134
end
1135+
10201136
function _compile_optional_affect(default, aff, cb, sys, dvs, ps; kwargs...)
10211137
if isnothing(aff) || aff == default
10221138
return nothing
10231139
else
10241140
return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
10251141
end
10261142
end
1143+
10271144
function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = nothing,
10281145
kwargs...)
10291146
cond = condition(cb)

src/systems/codegen_utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
234234
if wrap_code isa Tuple && symbolic_type(expr) == ScalarSymbolic()
235235
wrap_code = wrap_code[1]
236236
end
237+
@show build_function(expr, args...)[1]
237238
return build_function(expr, args...; wrap_code, similarto, kwargs...)
238239
end
239240

0 commit comments

Comments
 (0)