Skip to content

Commit 78df83c

Browse files
fix: compile symbolic affects after mtkcompile in complete
1 parent 20fe296 commit 78df83c

File tree

4 files changed

+190
-119
lines changed

4 files changed

+190
-119
lines changed

src/systems/abstractsystem.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,16 @@ function complete(
646646
if add_initial_parameters
647647
sys = add_initialization_parameters(sys; split)
648648
end
649+
if has_continuous_events(sys) && is_time_dependent(sys)
650+
@set! sys.continuous_events = complete.(
651+
get_continuous_events(sys); iv = get_iv(sys),
652+
alg_eqs = [alg_equations(sys); observed(sys)])
653+
end
654+
if has_discrete_events(sys) && is_time_dependent(sys)
655+
@set! sys.discrete_events = complete.(
656+
get_discrete_events(sys); iv = get_iv(sys),
657+
alg_eqs = [alg_equations(sys); observed(sys)])
658+
end
649659
end
650660
if split && has_index_cache(sys)
651661
@set! sys.index_cache = IndexCache(sys)

src/systems/callbacks.jl

Lines changed: 154 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,22 @@ function has_functional_affect(cb)
44
affects(cb) isa ImperativeAffect
55
end
66

7+
struct SymbolicAffect{K}
8+
affect::Vector{Equation}
9+
alg_eqs::Vector{Equation}
10+
discrete_parameters::Vector{Any}
11+
end
12+
13+
function SymbolicAffect(affect::Vector{Equation}; alg_eqs = Equation[],
14+
discrete_parameters = Any[], kwargs...)
15+
SymbolicAffect(affect, alg_eqs, discrete_parameters)
16+
end
17+
function SymbolicAffect(affect::SymbolicAffect; kwargs...)
18+
SymbolicAffect(affect.affect; alg_eqs = affect.alg_eqs,
19+
discrete_parameters = affect.discrete_parameters, kwargs...)
20+
end
21+
SymbolicAffect(affect; kwargs...) = affect
22+
723
struct AffectSystem
824
"""The internal implicit discrete system whose equations are solved to obtain values after the affect."""
925
system::AbstractSystem
@@ -15,6 +31,72 @@ struct AffectSystem
1531
discretes::Vector
1632
end
1733

34+
function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], kwargs...)
35+
AffectSystem(spec.affect; alg_eqs = vcat(spec.alg_eqs, alg_eqs), iv,
36+
discrete_parameters = spec.discrete_parameters, kwargs...)
37+
end
38+
39+
function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[],
40+
iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...)
41+
isempty(affect) && return nothing
42+
if isnothing(iv)
43+
iv = t_nounits
44+
@warn "No independent variable specified. Defaulting to t_nounits."
45+
end
46+
47+
discrete_parameters isa AbstractVector || (discrete_parameters = [discrete_parameters])
48+
discrete_parameters = unwrap.(discrete_parameters)
49+
50+
for p in discrete_parameters
51+
occursin(unwrap(iv), unwrap(p)) ||
52+
error("Non-time dependent parameter $p passed in as a discrete. Must be declared as @parameters $p(t).")
53+
end
54+
55+
dvs = OrderedSet()
56+
params = OrderedSet()
57+
_varsbuf = Set()
58+
for eq in affect
59+
if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic() ||
60+
symbolic_type(eq.lhs) === NotSymbolic())
61+
@warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1."
62+
end
63+
collect_vars!(dvs, params, eq, iv; op = Pre)
64+
empty!(_varsbuf)
65+
vars!(_varsbuf, eq; op = Pre)
66+
filter!(x -> iscall(x) && operation(x) isa Pre, _varsbuf)
67+
union!(params, _varsbuf)
68+
diffvs = collect_applied_operators(eq, Differential)
69+
union!(dvs, diffvs)
70+
end
71+
for eq in alg_eqs
72+
collect_vars!(dvs, params, eq, iv)
73+
end
74+
pre_params = filter(haspre value, params)
75+
sys_params = collect(setdiff(params, union(discrete_parameters, pre_params)))
76+
discretes = map(tovar, discrete_parameters)
77+
dvs = collect(dvs)
78+
_dvs = map(default_toterm, dvs)
79+
80+
rev_map = Dict(zip(discrete_parameters, discretes))
81+
subs = merge(rev_map, Dict(zip(dvs, _dvs)))
82+
affect = Symbolics.fast_substitute(affect, subs)
83+
alg_eqs = Symbolics.fast_substitute(alg_eqs, subs)
84+
85+
@named affectsys = System(
86+
vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)),
87+
collect(union(pre_params, sys_params)); is_discrete = true)
88+
affectsys = mtkcompile(affectsys; fully_determined = nothing)
89+
# get accessed parameters p from Pre(p) in the callback parameters
90+
accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params))))
91+
union!(accessed_params, sys_params)
92+
93+
# add scalarized unknowns to the map.
94+
_dvs = reduce(vcat, map(scalarize, _dvs), init = Any[])
95+
96+
AffectSystem(affectsys, collect(_dvs), collect(accessed_params),
97+
collect(discrete_parameters))
98+
end
99+
18100
system(a::AffectSystem) = a.system
19101
discretes(a::AffectSystem) = a.discretes
20102
unknowns(a::AffectSystem) = a.unknowns
@@ -159,40 +241,40 @@ will run as soon as the solver starts, while finalization affects will be execut
159241
"""
160242
struct SymbolicContinuousCallback <: AbstractCallback
161243
conditions::Vector{Equation}
162-
affect::Union{Affect, Nothing}
163-
affect_neg::Union{Affect, Nothing}
164-
initialize::Union{Affect, Nothing}
165-
finalize::Union{Affect, Nothing}
244+
affect::Union{Affect, SymbolicAffect, Nothing}
245+
affect_neg::Union{Affect, SymbolicAffect, Nothing}
246+
initialize::Union{Affect, SymbolicAffect, Nothing}
247+
finalize::Union{Affect, SymbolicAffect, Nothing}
166248
rootfind::Union{Nothing, SciMLBase.RootfindOpt}
167249
reinitializealg::SciMLBase.DAEInitializationAlgorithm
250+
end
168251

169-
function SymbolicContinuousCallback(
170-
conditions::Union{Equation, Vector{Equation}},
171-
affect = nothing;
172-
affect_neg = affect,
173-
initialize = nothing,
174-
finalize = nothing,
175-
rootfind = SciMLBase.LeftRootFind,
176-
reinitializealg = nothing,
177-
kwargs...)
178-
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
179-
180-
if isnothing(reinitializealg)
181-
if any(a -> a isa ImperativeAffect,
182-
[affect, affect_neg, initialize, finalize])
183-
reinitializealg = SciMLBase.CheckInit()
184-
else
185-
reinitializealg = SciMLBase.NoInit()
186-
end
252+
function SymbolicContinuousCallback(
253+
conditions::Union{Equation, Vector{Equation}},
254+
affect = nothing;
255+
affect_neg = affect,
256+
initialize = nothing,
257+
finalize = nothing,
258+
rootfind = SciMLBase.LeftRootFind,
259+
reinitializealg = nothing,
260+
kwargs...)
261+
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
262+
263+
if isnothing(reinitializealg)
264+
if any(a -> a isa ImperativeAffect,
265+
[affect, affect_neg, initialize, finalize])
266+
reinitializealg = SciMLBase.CheckInit()
267+
else
268+
reinitializealg = SciMLBase.NoInit()
187269
end
270+
end
188271

189-
new(conditions, make_affect(affect; kwargs...),
190-
make_affect(affect_neg; kwargs...),
191-
make_affect(initialize; kwargs...), make_affect(
192-
finalize; kwargs...),
193-
rootfind, reinitializealg)
194-
end # Default affect to nothing
195-
end
272+
SymbolicContinuousCallback(conditions, SymbolicAffect(affect; kwargs...),
273+
SymbolicAffect(affect_neg; kwargs...),
274+
SymbolicAffect(initialize; kwargs...), SymbolicAffect(
275+
finalize; kwargs...),
276+
rootfind, reinitializealg)
277+
end # Default affect to nothing
196278

197279
function SymbolicContinuousCallback(p::Pair, args...; kwargs...)
198280
SymbolicContinuousCallback(p[1], p[2], args...; kwargs...)
@@ -207,72 +289,18 @@ function SymbolicContinuousCallback(cb::Tuple, args...; kwargs...)
207289
end
208290
end
209291

292+
function complete(cb::SymbolicContinuousCallback; kwargs...)
293+
SymbolicContinuousCallback(cb.conditions, make_affect(cb.affect; kwargs...),
294+
make_affect(cb.affect_neg; kwargs...), make_affect(cb.initialize; kwargs...),
295+
make_affect(cb.finalize; kwargs...), cb.rootfind, cb.reinitializealg)
296+
end
297+
298+
make_affect(affect::SymbolicAffect; kwargs...) = AffectSystem(affect; kwargs...)
210299
make_affect(affect::Nothing; kwargs...) = nothing
211300
make_affect(affect::Tuple; kwargs...) = ImperativeAffect(affect...)
212301
make_affect(affect::NamedTuple; kwargs...) = ImperativeAffect(; affect...)
213302
make_affect(affect::Affect; kwargs...) = affect
214303

215-
function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],
216-
iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...)
217-
isempty(affect) && return nothing
218-
if isnothing(iv)
219-
iv = t_nounits
220-
@warn "No independent variable specified. Defaulting to t_nounits."
221-
end
222-
223-
discrete_parameters isa AbstractVector || (discrete_parameters = [discrete_parameters])
224-
discrete_parameters = unwrap.(discrete_parameters)
225-
226-
for p in discrete_parameters
227-
occursin(unwrap(iv), unwrap(p)) ||
228-
error("Non-time dependent parameter $p passed in as a discrete. Must be declared as @parameters $p(t).")
229-
end
230-
231-
dvs = OrderedSet()
232-
params = OrderedSet()
233-
_varsbuf = Set()
234-
for eq in affect
235-
if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic() ||
236-
symbolic_type(eq.lhs) === NotSymbolic())
237-
@warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1."
238-
end
239-
collect_vars!(dvs, params, eq, iv; op = Pre)
240-
empty!(_varsbuf)
241-
vars!(_varsbuf, eq; op = Pre)
242-
filter!(x -> iscall(x) && operation(x) isa Pre, _varsbuf)
243-
union!(params, _varsbuf)
244-
diffvs = collect_applied_operators(eq, Differential)
245-
union!(dvs, diffvs)
246-
end
247-
for eq in alg_eqs
248-
collect_vars!(dvs, params, eq, iv)
249-
end
250-
pre_params = filter(haspre value, params)
251-
sys_params = collect(setdiff(params, union(discrete_parameters, pre_params)))
252-
discretes = map(tovar, discrete_parameters)
253-
dvs = collect(dvs)
254-
_dvs = map(default_toterm, dvs)
255-
256-
rev_map = Dict(zip(discrete_parameters, discretes))
257-
subs = merge(rev_map, Dict(zip(dvs, _dvs)))
258-
affect = Symbolics.fast_substitute(affect, subs)
259-
alg_eqs = Symbolics.fast_substitute(alg_eqs, subs)
260-
261-
@named affectsys = System(
262-
vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)),
263-
collect(union(pre_params, sys_params)); is_discrete = true)
264-
affectsys = mtkcompile(affectsys; fully_determined = nothing)
265-
# get accessed parameters p from Pre(p) in the callback parameters
266-
accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params))))
267-
union!(accessed_params, sys_params)
268-
269-
# add scalarized unknowns to the map.
270-
_dvs = reduce(vcat, map(scalarize, _dvs), init = Any[])
271-
272-
AffectSystem(affectsys, collect(_dvs), collect(accessed_params),
273-
collect(discrete_parameters))
274-
end
275-
276304
function make_affect(affect; kwargs...)
277305
error("Malformed affect $(affect). This should be a vector of equations or a tuple specifying a functional affect.")
278306
end
@@ -374,30 +402,30 @@ Arguments:
374402
"""
375403
struct SymbolicDiscreteCallback <: AbstractCallback
376404
conditions::Union{Number, Vector{<:Number}, Symbolic{Bool}}
377-
affect::Union{Affect, Nothing}
378-
initialize::Union{Affect, Nothing}
379-
finalize::Union{Affect, Nothing}
405+
affect::Union{Affect, SymbolicAffect, Nothing}
406+
initialize::Union{Affect, SymbolicAffect, Nothing}
407+
finalize::Union{Affect, SymbolicAffect, Nothing}
380408
reinitializealg::SciMLBase.DAEInitializationAlgorithm
409+
end
381410

382-
function SymbolicDiscreteCallback(
383-
condition::Union{Symbolic{Bool}, Number, Vector{<:Number}}, affect = nothing;
384-
initialize = nothing, finalize = nothing,
385-
reinitializealg = nothing, kwargs...)
386-
c = is_timed_condition(condition) ? condition : value(scalarize(condition))
387-
388-
if isnothing(reinitializealg)
389-
if any(a -> a isa ImperativeAffect,
390-
[affect, initialize, finalize])
391-
reinitializealg = SciMLBase.CheckInit()
392-
else
393-
reinitializealg = SciMLBase.NoInit()
394-
end
411+
function SymbolicDiscreteCallback(
412+
condition::Union{Symbolic{Bool}, Number, Vector{<:Number}}, affect = nothing;
413+
initialize = nothing, finalize = nothing,
414+
reinitializealg = nothing, kwargs...)
415+
c = is_timed_condition(condition) ? condition : value(scalarize(condition))
416+
417+
if isnothing(reinitializealg)
418+
if any(a -> a isa ImperativeAffect,
419+
[affect, initialize, finalize])
420+
reinitializealg = SciMLBase.CheckInit()
421+
else
422+
reinitializealg = SciMLBase.NoInit()
395423
end
396-
new(c, make_affect(affect; kwargs...),
397-
make_affect(initialize; kwargs...),
398-
make_affect(finalize; kwargs...), reinitializealg)
399-
end # Default affect to nothing
400-
end
424+
end
425+
SymbolicDiscreteCallback(c, SymbolicAffect(affect; kwargs...),
426+
SymbolicAffect(initialize; kwargs...),
427+
SymbolicAffect(finalize; kwargs...), reinitializealg)
428+
end # Default affect to nothing
401429

402430
function SymbolicDiscreteCallback(p::Pair, args...; kwargs...)
403431
SymbolicDiscreteCallback(p[1], p[2], args...; kwargs...)
@@ -412,6 +440,12 @@ function SymbolicDiscreteCallback(cb::Tuple, args...; kwargs...)
412440
end
413441
end
414442

443+
function complete(cb::SymbolicDiscreteCallback; kwargs...)
444+
SymbolicDiscreteCallback(cb.conditions, make_affect(cb.affect; kwargs...),
445+
make_affect(cb.initialize; kwargs...),
446+
make_affect(cb.finalize; kwargs...), cb.reinitializealg)
447+
end
448+
415449
function is_timed_condition(condition::T) where {T}
416450
if T === Num
417451
false
@@ -457,6 +491,12 @@ function namespace_affects(affect::AffectSystem, s)
457491
renamespace.((s,), parameters(affect)),
458492
renamespace.((s,), discretes(affect)))
459493
end
494+
function namespace_affects(affect::SymbolicAffect, s)
495+
SymbolicAffect(
496+
namespace_equation.(affect.affect, (s,)), namespace_equation.(affect.alg_eqs, (s,)),
497+
renamespace.((s,), affect.discrete_parameters))
498+
end
499+
460500
namespace_affects(af::Nothing, s) = nothing
461501

462502
function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
@@ -1060,12 +1100,8 @@ end
10601100
"""
10611101
Process the symbolic events of a system.
10621102
"""
1063-
function create_symbolic_events(cont_events, disc_events, sys_eqs, iv)
1064-
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !is_diff_equation(eq),
1065-
sys_eqs)
1066-
cont_callbacks = to_cb_vector(cont_events; CB_TYPE = SymbolicContinuousCallback,
1067-
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
1068-
disc_callbacks = to_cb_vector(disc_events; CB_TYPE = SymbolicDiscreteCallback,
1069-
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
1103+
function create_symbolic_events(cont_events, disc_events)
1104+
cont_callbacks = to_cb_vector(cont_events; CB_TYPE = SymbolicContinuousCallback)
1105+
disc_callbacks = to_cb_vector(disc_events; CB_TYPE = SymbolicDiscreteCallback)
10701106
cont_callbacks, disc_callbacks
10711107
end

src/systems/system.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
389389
end
390390
continuous_events,
391391
discrete_events = create_symbolic_events(
392-
continuous_events, discrete_events, eqs, iv)
392+
continuous_events, discrete_events)
393393

394394
if iv === nothing && (!isempty(continuous_events) || !isempty(discrete_events))
395395
throw(EventsInTimeIndependentSystemError(continuous_events, discrete_events))

test/symbolic_events.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,3 +1378,28 @@ end
13781378
@test SciMLBase.successful_retcode(sol)
13791379
@test sol[x, end]1.0 atol=1e-6
13801380
end
1381+
1382+
@testset "Symbolic affects are compiled in `complete`" begin
1383+
@parameters g
1384+
@variables x(t) [state_priority = 10.0] y(t) [guess = 1.0]
1385+
@variables λ(t) [guess = 1.0]
1386+
eqs = [D(D(x)) ~ λ * x
1387+
D(D(y)) ~ λ * y - g
1388+
x^2 + y^2 ~ 1]
1389+
cevts = [[x ~ 0.0] => [D(x) ~ Pre(D(x)) + 1sign(Pre(D(x)))]]
1390+
@named pend = System(eqs, t; continuous_events = cevts)
1391+
1392+
scc = only(continuous_events(pend))
1393+
@test scc.affect isa ModelingToolkit.SymbolicAffect
1394+
1395+
pend = mtkcompile(pend)
1396+
1397+
scc = only(continuous_events(pend))
1398+
@test scc.affect isa ModelingToolkit.AffectSystem
1399+
@test length(ModelingToolkit.all_equations(scc.affect)) == 5 # 1 affect, 3 algebraic, 1 observed
1400+
1401+
u0 = [x => -1/2, D(x) => 1/2, g => 1]
1402+
prob = ODEProblem(pend, u0, (0.0, 5.0))
1403+
sol = solve(prob, FBDF())
1404+
@test SciMLBase.successful_retcode(sol)
1405+
end

0 commit comments

Comments
 (0)