Skip to content

Commit 4d0871d

Browse files
committed
fix: improve performance of implicit_affect
1 parent e98781a commit 4d0871d

File tree

6 files changed

+69
-60
lines changed

6 files changed

+69
-60
lines changed

src/systems/callbacks.jl

Lines changed: 49 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,13 @@ function has_functional_affect(cb)
5656
end
5757

5858
struct AffectSystem
59+
"""The internal implicit discrete system whose equations are solved to obtain values after the affect."""
5960
system::ImplicitDiscreteSystem
61+
"""Unknowns of the parent ODESystem whose values are modified or accessed by the affect."""
6062
unknowns::Vector
63+
"""Parameters of the parent ODESystem whose values are accessed by the affect."""
6164
parameters::Vector
65+
"""Parameters of the parent ODESystem whose values are modified by the affect."""
6266
discretes::Vector
6367
"""Maps the symbols of unknowns/observed in the ImplicitDiscreteSystem to its corresponding unknown/parameter in the parent system."""
6468
aff_to_sys::Dict
@@ -226,10 +230,12 @@ struct SymbolicContinuousCallback <: AbstractCallback
226230
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
227231

228232
if isnothing(reinitializealg)
229-
any(a -> (a isa FunctionalAffect || a isa ImperativeAffect),
230-
[affect, affect_neg, initialize, finalize]) ?
231-
reinitializealg = SciMLBase.CheckInit() :
232-
reinitializealg = SciMLBase.NoInit()
233+
if any(a -> (a isa FunctionalAffect || a isa ImperativeAffect),
234+
[affect, affect_neg, initialize, finalize])
235+
reinitializealg = SciMLBase.CheckInit()
236+
else
237+
reinitializealg = SciMLBase.NoInit()
238+
end
233239
end
234240

235241
new(conditions, make_affect(affect; kwargs...),
@@ -261,8 +267,6 @@ make_affect(affect::Affect; kwargs...) = affect
261267
function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],
262268
iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...)
263269
isempty(affect) && return nothing
264-
isempty(alg_eqs) && warn_no_algebraic &&
265-
@warn "No algebraic equations were found for the callback defined by $(join(affect, ", ")). If the system has no algebraic equations, this can be disregarded. Otherwise pass in `alg_eqs` to the SymbolicContinuousCallback constructor."
266270
if isnothing(iv)
267271
iv = t_nounits
268272
@warn "No independent variable specified. Defaulting to t_nounits."
@@ -304,7 +308,7 @@ function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],
304308
@named affectsys = ImplicitDiscreteSystem(
305309
vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)),
306310
collect(union(pre_params, sys_params)))
307-
affectsys = structural_simplify(affectsys; fully_determined = false)
311+
affectsys = structural_simplify(affectsys; fully_determined = nothing)
308312
# get accessed parameters p from Pre(p) in the callback parameters
309313
accessed_params = filter(isparameter, map(unPre, collect(pre_params)))
310314
union!(accessed_params, sys_params)
@@ -415,7 +419,7 @@ The condition can be one of:
415419
- eqs::Vector{Symbolic} - events trigger when the condition evaluates to true
416420
417421
Arguments:
418-
- iv: The independent variable of the system. This must be specified if the independent variable appaers in one of the equations explicitly, as in x ~ t + 1.
422+
- iv: The independent variable of the system. This must be specified if the independent variable appears in one of the equations explicitly, as in x ~ t + 1.
419423
- alg_eqs: Algebraic equations of the system that must be satisfied after the callback occurs.
420424
"""
421425
struct SymbolicDiscreteCallback <: AbstractCallback
@@ -473,7 +477,6 @@ to_cb_vector(cbs::Union{Nothing, Vector{Nothing}}; kwargs...) = AbstractCallback
473477
to_cb_vector(cb::AbstractCallback; kwargs...) = [cb]
474478
function to_cb_vector(cbs; CB_TYPE = SymbolicContinuousCallback, kwargs...)
475479
if cbs isa Pair
476-
@show cbs
477480
[CB_TYPE(cbs; kwargs...)]
478481
else
479482
Vector{CB_TYPE}([CB_TYPE(cb; kwargs...) for cb in cbs])
@@ -741,13 +744,17 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
741744
[fill(i, num_eqs[i]) for i in eachindex(affects)])
742745
eqs = reduce(vcat, eqs)
743746

744-
affect = function (integ, idx)
745-
affects[eq2affect[idx]](integ)
747+
affect = let eq2affect = eq2affect, affects = affects
748+
function (integ, idx)
749+
affects[eq2affect[idx]](integ)
750+
end
746751
end
747-
affect_neg = function (integ, idx)
748-
f = affect_negs[eq2affect[idx]]
749-
isnothing(f) && return
750-
f(integ)
752+
affect_neg = let eq2affect = eq2affect, affect_negs = affect_negs
753+
function (integ, idx)
754+
f = affect_negs[eq2affect[idx]]
755+
isnothing(f) && return
756+
f(integ)
757+
end
751758
end
752759
initialize = wrap_vector_optional_affect(inits, SciMLBase.INITIALIZE_DEFAULT)
753760
finalize = wrap_vector_optional_affect(finals, SciMLBase.FINALIZE_DEFAULT)
@@ -832,10 +839,10 @@ function compile_affect(
832839
end
833840

834841
function wrap_save_discretes(f, save_idxs)
835-
let save_idxs = save_idxs
842+
let save_idxs = save_idxs, f = f
836843
if f === SciMLBase.INITIALIZE_DEFAULT
837844
(c, u, t, i) -> begin
838-
isnothing(f) || f(c, u, t, i)
845+
f(c, u, t, i)
839846
for idx in save_idxs
840847
SciMLBase.save_discretes!(i, idx)
841848
end
@@ -918,40 +925,40 @@ function compile_equational_affect(
918925
wrap_code = add_integrator_header(sys, integ, :p),
919926
expression = Val{false}, outputidxs = p_idxs, wrap_mtkparameters, cse = false)
920927

921-
return function explicit_affect!(integ)
922-
isempty(dvs_to_update) || u_up!(integ)
923-
isempty(ps_to_update) || p_up!(integ)
924-
reset_jumps && reset_aggregated_jumps!(integ)
928+
return let dvs_to_update = dvs_to_update, ps_to_update = ps_to_update, reset_jump = reset_jump, u_up! = u_up!, p_up! = p_up!
929+
function explicit_affect!(integ)
930+
isempty(dvs_to_update) || u_up!(integ)
931+
isempty(ps_to_update) || p_up!(integ)
932+
reset_jumps && reset_aggregated_jumps!(integ)
933+
end
925934
end
926935
else
927936
return let dvs_to_update = dvs_to_update, aff_map = aff_map, sys_map = sys_map,
928-
affsys = affsys, ps_to_update = ps_to_update, aff = aff
937+
affsys = affsys, ps_to_update = ps_to_update, aff = aff, sys = sys
938+
939+
affu_getters = [getsym(affsys, u) for u in unknowns(affsys)]
940+
affp_getters = [getsym(affsys, unPre(p)) for p in parameters(affsys)]
941+
u_setters = [setsym(sys, u) for u in dvs_to_update]
942+
p_setters = [setsym(sys, p) for p in ps_to_update]
943+
solu_getters = [getsym(affsys, sys_map[u]) for u in dvs_to_update]
944+
solp_getters = [getsym(affsys, sys_map[p]) for p in ps_to_update]
945+
946+
affprob = ImplicitDiscreteProblem(affsys, u0map, (integ.t, integ.t), pmap;
947+
build_initializeprob = false, check_length = false)
929948

930949
function implicit_affect!(integ)
931-
pmap = Pair[]
932-
for pre_p in parameters(affsys)
933-
p = unPre(pre_p)
934-
pval = isparameter(p) ? integ.ps[p] : integ[p]
935-
push!(pmap, pre_p => pval)
936-
end
937-
u0 = Pair[]
938-
for u in unknowns(affsys)
939-
uval = isparameter(aff_map[u]) ? integ.ps[aff_map[u]] : integ[u]
940-
push!(u0, u => uval)
941-
end
942-
affprob = ImplicitDiscreteProblem(affsys, u0, (integ.t, integ.t), pmap;
943-
build_initializeprob = false, check_length = false)
944-
@show pmap
945-
@show u0
950+
pmap = [p => getp(integ) for (p, getp) in zip(parameters(affsys), p_getters)]
951+
u0map = [u => getu(integ) for (u, getu) in zip(unknowns(affsys), u_getters)]
952+
affprob = remake(affprob, u0 = u0map, p = pmap)
946953
affsol = init(affprob, IDSolve())
947-
@show affsol
948954
(check_error(affsol) === ReturnCode.InitialFailure) &&
949955
throw(UnsolvableCallbackError(all_equations(aff)))
950-
for u in dvs_to_update
951-
integ[u] = affsol[sys_map[u]]
956+
957+
for (setu!, getu) in zip(u_setters, solu_getters)
958+
setu!(integ, getu(affsol))
952959
end
953-
for p in ps_to_update
954-
integ.ps[p] = affsol[sys_map[p]]
960+
for (setp!, getp) in zip(p_setters, solp_getters)
961+
setp!(integ, getp(affsol))
955962
end
956963
end
957964
end

src/systems/discrete_system/implicit_discrete_system.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ function shift_u0map_forward(sys::ImplicitDiscreteSystem, u0map, defs)
299299
v = u0map[k]
300300
if !((op = operation(k)) isa Shift)
301301
isnothing(getunshifted(k)) &&
302-
@warn "Initial condition given in term of current state of the unknown. If `build_initializeprob = false`, this may be overriden by the implicit discrete solver."
302+
@warn "Initial condition given in term of current state of the unknown. If `build_initializeprob = false`, this may be overridden by the implicit discrete solver."
303303

304304
updated[k] = v
305305
elseif op.steps > 0

src/systems/jumps/jumpsystem.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
7575
Type of the system.
7676
"""
7777
connector_type::Any
78+
"""
79+
A `Vector{SymbolicContinuousCallback}` that model events.
80+
The integrator will use root finding to guarantee that it steps at each zero crossing.
81+
"""
7882
continuous_events::Vector{SymbolicContinuousCallback}
7983
"""
8084
A `Vector{SymbolicDiscreteCallback}` that models events. Symbolic

src/systems/problem_utils.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,9 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
380380
vals = promote_to_concrete(vals; tofloat = tofloat, use_union = false)
381381
end
382382

383-
if container_type <: Tuple
383+
if isempty(vals)
384+
return nothing
385+
elseif container_type <: Tuple
384386
return (vals...,)
385387
else
386388
return SymbolicUtils.Code.create_array(container_type, eltype(vals), Val{1}(),

src/systems/systems.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,6 @@ function structural_simplify(
4242
end
4343
if newsys isa DiscreteSystem &&
4444
any(eq -> symbolic_type(eq.lhs) == NotSymbolic(), equations(newsys))
45-
#error("""
46-
# Encountered algebraic equations when simplifying discrete system. Please construct \
47-
# an ImplicitDiscreteSystem instead.
48-
#""")
4945
end
5046
for pass in additional_passes
5147
newsys = pass(newsys)

test/symbolic_events.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,43 +22,43 @@ affect_neg = [x ~ 1]
2222
e = SymbolicContinuousCallback(eqs[])
2323
@test e isa SymbolicContinuousCallback
2424
@test isequal(equations(e), eqs)
25-
@test e.affect == nothing
26-
@test e.affect_neg == nothing
25+
@test e.affect === nothing
26+
@test e.affect_neg === nothing
2727
@test e.rootfind == SciMLBase.LeftRootFind
2828

2929
e = SymbolicContinuousCallback(eqs)
3030
@test e isa SymbolicContinuousCallback
3131
@test isequal(equations(e), eqs)
32-
@test e.affect == nothing
33-
@test e.affect_neg == nothing
32+
@test e.affect === nothing
33+
@test e.affect_neg === nothing
3434
@test e.rootfind == SciMLBase.LeftRootFind
3535

3636
e = SymbolicContinuousCallback(eqs, nothing)
3737
@test e isa SymbolicContinuousCallback
3838
@test isequal(equations(e), eqs)
39-
@test e.affect == nothing
40-
@test e.affect_neg == nothing
39+
@test e.affect === nothing
40+
@test e.affect_neg === nothing
4141
@test e.rootfind == SciMLBase.LeftRootFind
4242

4343
e = SymbolicContinuousCallback(eqs[], nothing)
4444
@test e isa SymbolicContinuousCallback
4545
@test isequal(equations(e), eqs)
46-
@test e.affect == nothing
47-
@test e.affect_neg == nothing
46+
@test e.affect === nothing
47+
@test e.affect_neg === nothing
4848
@test e.rootfind == SciMLBase.LeftRootFind
4949

5050
e = SymbolicContinuousCallback(eqs => nothing)
5151
@test e isa SymbolicContinuousCallback
5252
@test isequal(equations(e), eqs)
53-
@test e.affect == nothing
54-
@test e.affect_neg == nothing
53+
@test e.affect === nothing
54+
@test e.affect_neg === nothing
5555
@test e.rootfind == SciMLBase.LeftRootFind
5656

5757
e = SymbolicContinuousCallback(eqs[] => nothing)
5858
@test e isa SymbolicContinuousCallback
5959
@test isequal(equations(e), eqs)
60-
@test e.affect == nothing
61-
@test e.affect_neg == nothing
60+
@test e.affect === nothing
61+
@test e.affect_neg === nothing
6262
@test e.rootfind == SciMLBase.LeftRootFind
6363

6464
## With affect

0 commit comments

Comments
 (0)