Skip to content

Commit 30e6af8

Browse files
vyuduAayushSabharwal
authored andcommitted
fix: improve performance of implicit_affect
1 parent 3498e94 commit 30e6af8

File tree

6 files changed

+72
-60
lines changed

6 files changed

+72
-60
lines changed

src/systems/callbacks.jl

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,13 @@ function has_functional_affect(cb)
5656
end
5757

5858
struct AffectSystem
59+
"""The internal implicit discrete system whose equations are solved to obtain values after the affect."""
5960
system::ImplicitDiscreteSystem
61+
"""Unknowns of the parent ODESystem whose values are modified or accessed by the affect."""
6062
unknowns::Vector
63+
"""Parameters of the parent ODESystem whose values are accessed by the affect."""
6164
parameters::Vector
65+
"""Parameters of the parent ODESystem whose values are modified by the affect."""
6266
discretes::Vector
6367
"""Maps the symbols of unknowns/observed in the ImplicitDiscreteSystem to its corresponding unknown/parameter in the parent system."""
6468
aff_to_sys::Dict
@@ -226,10 +230,12 @@ struct SymbolicContinuousCallback <: AbstractCallback
226230
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
227231

228232
if isnothing(reinitializealg)
229-
any(a -> (a isa FunctionalAffect || a isa ImperativeAffect),
230-
[affect, affect_neg, initialize, finalize]) ?
231-
reinitializealg = SciMLBase.CheckInit() :
232-
reinitializealg = SciMLBase.NoInit()
233+
if any(a -> (a isa FunctionalAffect || a isa ImperativeAffect),
234+
[affect, affect_neg, initialize, finalize])
235+
reinitializealg = SciMLBase.CheckInit()
236+
else
237+
reinitializealg = SciMLBase.NoInit()
238+
end
233239
end
234240

235241
new(conditions, make_affect(affect; kwargs...),
@@ -261,8 +267,6 @@ make_affect(affect::Affect; kwargs...) = affect
261267
function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],
262268
iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...)
263269
isempty(affect) && return nothing
264-
isempty(alg_eqs) && warn_no_algebraic &&
265-
@warn "No algebraic equations were found for the callback defined by $(join(affect, ", ")). If the system has no algebraic equations, this can be disregarded. Otherwise pass in `alg_eqs` to the SymbolicContinuousCallback constructor."
266270
if isnothing(iv)
267271
iv = t_nounits
268272
@warn "No independent variable specified. Defaulting to t_nounits."
@@ -304,7 +308,7 @@ function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],
304308
@named affectsys = ImplicitDiscreteSystem(
305309
vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)),
306310
collect(union(pre_params, sys_params)))
307-
affectsys = structural_simplify(affectsys; fully_determined = false)
311+
affectsys = structural_simplify(affectsys; fully_determined = nothing)
308312
# get accessed parameters p from Pre(p) in the callback parameters
309313
accessed_params = filter(isparameter, map(unPre, collect(pre_params)))
310314
union!(accessed_params, sys_params)
@@ -415,7 +419,7 @@ The condition can be one of:
415419
- eqs::Vector{Symbolic} - events trigger when the condition evaluates to true
416420
417421
Arguments:
418-
- iv: The independent variable of the system. This must be specified if the independent variable appaers in one of the equations explicitly, as in x ~ t + 1.
422+
- iv: The independent variable of the system. This must be specified if the independent variable appears in one of the equations explicitly, as in x ~ t + 1.
419423
- alg_eqs: Algebraic equations of the system that must be satisfied after the callback occurs.
420424
"""
421425
struct SymbolicDiscreteCallback <: AbstractCallback
@@ -471,7 +475,6 @@ to_cb_vector(cbs::Union{Nothing, Vector{Nothing}}; kwargs...) = AbstractCallback
471475
to_cb_vector(cb::AbstractCallback; kwargs...) = [cb]
472476
function to_cb_vector(cbs; CB_TYPE = SymbolicContinuousCallback, kwargs...)
473477
if cbs isa Pair
474-
@show cbs
475478
[CB_TYPE(cbs; kwargs...)]
476479
else
477480
Vector{CB_TYPE}([CB_TYPE(cb; kwargs...) for cb in cbs])
@@ -739,13 +742,17 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
739742
[fill(i, num_eqs[i]) for i in eachindex(affects)])
740743
eqs = reduce(vcat, eqs)
741744

742-
affect = function (integ, idx)
743-
affects[eq2affect[idx]](integ)
745+
affect = let eq2affect = eq2affect, affects = affects
746+
function (integ, idx)
747+
affects[eq2affect[idx]](integ)
748+
end
744749
end
745-
affect_neg = function (integ, idx)
746-
f = affect_negs[eq2affect[idx]]
747-
isnothing(f) && return
748-
f(integ)
750+
affect_neg = let eq2affect = eq2affect, affect_negs = affect_negs
751+
function (integ, idx)
752+
f = affect_negs[eq2affect[idx]]
753+
isnothing(f) && return
754+
f(integ)
755+
end
749756
end
750757
initialize = wrap_vector_optional_affect(inits, SciMLBase.INITIALIZE_DEFAULT)
751758
finalize = wrap_vector_optional_affect(finals, SciMLBase.FINALIZE_DEFAULT)
@@ -830,10 +837,10 @@ function compile_affect(
830837
end
831838

832839
function wrap_save_discretes(f, save_idxs)
833-
let save_idxs = save_idxs
840+
let save_idxs = save_idxs, f = f
834841
if f === SciMLBase.INITIALIZE_DEFAULT
835842
(c, u, t, i) -> begin
836-
isnothing(f) || f(c, u, t, i)
843+
f(c, u, t, i)
837844
for idx in save_idxs
838845
SciMLBase.save_discretes!(i, idx)
839846
end
@@ -916,40 +923,43 @@ function compile_equational_affect(
916923
wrap_code = add_integrator_header(sys, integ, :p),
917924
expression = Val{false}, outputidxs = p_idxs, wrap_mtkparameters, cse = false)
918925

919-
return function explicit_affect!(integ)
920-
isempty(dvs_to_update) || u_up!(integ)
921-
isempty(ps_to_update) || p_up!(integ)
922-
reset_jumps && reset_aggregated_jumps!(integ)
926+
return let dvs_to_update = dvs_to_update, ps_to_update = ps_to_update, reset_jump = reset_jump, u_up! = u_up!, p_up! = p_up!
927+
function explicit_affect!(integ)
928+
isempty(dvs_to_update) || u_up!(integ)
929+
isempty(ps_to_update) || p_up!(integ)
930+
reset_jumps && reset_aggregated_jumps!(integ)
931+
end
923932
end
924933
else
925934
return let dvs_to_update = dvs_to_update, aff_map = aff_map, sys_map = sys_map,
926-
affsys = affsys, ps_to_update = ps_to_update, aff = aff
935+
affsys = affsys, ps_to_update = ps_to_update, aff = aff, sys = sys
936+
937+
dvs_to_access = unknowns(affsys)
938+
ps_to_access = parameters(affsys)
939+
940+
u_getters = [getsym(sys, aff_map[u]) for u in dvs_to_access]
941+
p_getters = [getsym(sys, unPre(p)) for p in ps_to_access]
942+
u_setters = [setsym(sys, u) for u in dvs_to_update]
943+
p_setters = [setsym(sys, p) for p in ps_to_update]
944+
solu_getters = [getsym(affsys, sys_map[u]) for u in dvs_to_update]
945+
solp_getters = [getsym(affsys, sys_map[p]) for p in ps_to_update]
946+
947+
affprob = ImplicitDiscreteProblem(affsys, u0map, (integ.t, integ.t), pmap;
948+
build_initializeprob = false, check_length = false)
927949

928950
function implicit_affect!(integ)
929-
pmap = Pair[]
930-
for pre_p in parameters(affsys)
931-
p = unPre(pre_p)
932-
pval = isparameter(p) ? integ.ps[p] : integ[p]
933-
push!(pmap, pre_p => pval)
934-
end
935-
u0 = Pair[]
936-
for u in unknowns(affsys)
937-
uval = isparameter(aff_map[u]) ? integ.ps[aff_map[u]] : integ[u]
938-
push!(u0, u => uval)
939-
end
940-
affprob = ImplicitDiscreteProblem(affsys, u0, (integ.t, integ.t), pmap;
941-
build_initializeprob = false, check_length = false)
942-
@show pmap
943-
@show u0
951+
pmap = [p => getp(integ) for (p, getp) in zip(parameters(affsys), p_getters)]
952+
u0map = [u => getu(integ) for (u, getu) in zip(unknowns(affsys), u_getters)]
953+
affprob = remake(affprob, u0 = u0map, p = pmap)
944954
affsol = init(affprob, IDSolve())
945-
@show affsol
946955
(check_error(affsol) === ReturnCode.InitialFailure) &&
947956
throw(UnsolvableCallbackError(all_equations(aff)))
948-
for u in dvs_to_update
949-
integ[u] = affsol[sys_map[u]]
957+
958+
for (setu!, getu) in zip(u_setters, solu_getters)
959+
setu!(integ, getu(affsol))
950960
end
951-
for p in ps_to_update
952-
integ.ps[p] = affsol[sys_map[p]]
961+
for (setp!, getp) in zip(p_setters, solp_getters)
962+
setp!(integ, getp(affsol))
953963
end
954964
end
955965
end

src/systems/discrete_system/implicit_discrete_system.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ function shift_u0map_forward(sys::ImplicitDiscreteSystem, u0map, defs)
299299
v = u0map[k]
300300
if !((op = operation(k)) isa Shift)
301301
isnothing(getunshifted(k)) &&
302-
@warn "Initial condition given in term of current state of the unknown. If `build_initializeprob = false`, this may be overriden by the implicit discrete solver."
302+
@warn "Initial condition given in term of current state of the unknown. If `build_initializeprob = false`, this may be overridden by the implicit discrete solver."
303303

304304
updated[k] = v
305305
elseif op.steps > 0

src/systems/jumps/jumpsystem.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
7575
Type of the system.
7676
"""
7777
connector_type::Any
78+
"""
79+
A `Vector{SymbolicContinuousCallback}` that model events.
80+
The integrator will use root finding to guarantee that it steps at each zero crossing.
81+
"""
7882
continuous_events::Vector{SymbolicContinuousCallback}
7983
"""
8084
A `Vector{SymbolicDiscreteCallback}` that models events. Symbolic

src/systems/problem_utils.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,9 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
380380
vals = promote_to_concrete(vals; tofloat = tofloat, use_union = false)
381381
end
382382

383-
if container_type <: Tuple
383+
if isempty(vals)
384+
return nothing
385+
elseif container_type <: Tuple
384386
return (vals...,)
385387
else
386388
return SymbolicUtils.Code.create_array(container_type, eltype(vals), Val{1}(),

src/systems/systems.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,6 @@ function structural_simplify(
4343
end
4444
if newsys isa DiscreteSystem &&
4545
any(eq -> symbolic_type(eq.lhs) == NotSymbolic(), equations(newsys))
46-
#error("""
47-
# Encountered algebraic equations when simplifying discrete system. Please construct \
48-
# an ImplicitDiscreteSystem instead.
49-
#""")
5046
end
5147
for pass in additional_passes
5248
newsys = pass(newsys)

test/symbolic_events.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,43 +22,43 @@ affect_neg = [x ~ 1]
2222
e = SymbolicContinuousCallback(eqs[])
2323
@test e isa SymbolicContinuousCallback
2424
@test isequal(equations(e), eqs)
25-
@test e.affect == nothing
26-
@test e.affect_neg == nothing
25+
@test e.affect === nothing
26+
@test e.affect_neg === nothing
2727
@test e.rootfind == SciMLBase.LeftRootFind
2828

2929
e = SymbolicContinuousCallback(eqs)
3030
@test e isa SymbolicContinuousCallback
3131
@test isequal(equations(e), eqs)
32-
@test e.affect == nothing
33-
@test e.affect_neg == nothing
32+
@test e.affect === nothing
33+
@test e.affect_neg === nothing
3434
@test e.rootfind == SciMLBase.LeftRootFind
3535

3636
e = SymbolicContinuousCallback(eqs, nothing)
3737
@test e isa SymbolicContinuousCallback
3838
@test isequal(equations(e), eqs)
39-
@test e.affect == nothing
40-
@test e.affect_neg == nothing
39+
@test e.affect === nothing
40+
@test e.affect_neg === nothing
4141
@test e.rootfind == SciMLBase.LeftRootFind
4242

4343
e = SymbolicContinuousCallback(eqs[], nothing)
4444
@test e isa SymbolicContinuousCallback
4545
@test isequal(equations(e), eqs)
46-
@test e.affect == nothing
47-
@test e.affect_neg == nothing
46+
@test e.affect === nothing
47+
@test e.affect_neg === nothing
4848
@test e.rootfind == SciMLBase.LeftRootFind
4949

5050
e = SymbolicContinuousCallback(eqs => nothing)
5151
@test e isa SymbolicContinuousCallback
5252
@test isequal(equations(e), eqs)
53-
@test e.affect == nothing
54-
@test e.affect_neg == nothing
53+
@test e.affect === nothing
54+
@test e.affect_neg === nothing
5555
@test e.rootfind == SciMLBase.LeftRootFind
5656

5757
e = SymbolicContinuousCallback(eqs[] => nothing)
5858
@test e isa SymbolicContinuousCallback
5959
@test isequal(equations(e), eqs)
60-
@test e.affect == nothing
61-
@test e.affect_neg == nothing
60+
@test e.affect === nothing
61+
@test e.affect_neg === nothing
6262
@test e.rootfind == SciMLBase.LeftRootFind
6363

6464
## With affect

0 commit comments

Comments
 (0)