Skip to content

Commit dcafe65

Browse files
vyuduAayushSabharwal
authored andcommitted
fix: improve performance of implicit affect
1 parent 348f883 commit dcafe65

File tree

4 files changed

+29
-32
lines changed

4 files changed

+29
-32
lines changed

src/systems/callbacks.jl

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -940,32 +940,28 @@ function compile_equational_affect(
940940
dvs_to_access = unknowns(affsys)
941941
ps_to_access = parameters(affsys)
942942

943-
u_getters = [getsym(sys, aff_map[u]) for u in dvs_to_access]
944-
p_getters = [getsym(sys, unPre(p)) for p in ps_to_access]
945-
u_setters = [setsym(sys, u) for u in dvs_to_update]
946-
p_setters = [setsym(sys, p) for p in ps_to_update]
947-
affu_getters = [getsym(affsys, sys_map[u]) for u in dvs_to_update]
948-
affp_getters = [getsym(affsys, sys_map[p]) for p in ps_to_update]
943+
u_getter = getsym(sys, [aff_map[u] for u in dvs_to_access])
944+
p_getter = getsym(sys, [unPre(p) for p in ps_to_access])
945+
u_setter! = setsym(sys, dvs_to_update)
946+
p_setter! = setsym(sys, ps_to_update)
947+
affu_getter = getsym(affsys, [sys_map[u] for u in dvs_to_update])
948+
affp_getter = getsym(affsys, [sys_map[p] for p in ps_to_update])
949949

950950
affprob = ImplicitDiscreteProblem(affsys, [dv => 0 for dv in dvs_to_access],
951951
(0, 0), [p => 0 for p in ps_to_access];
952952
build_initializeprob = false, check_length = false)
953953

954954
function implicit_affect!(integ)
955-
pmap = Pair[p => getp(integ) for (p, getp) in zip(ps_to_access, p_getters)]
956-
u0map = Pair[u => getu(integ)
957-
for (u, getu) in zip(dvs_to_access, u_getters)]
958-
affprob = remake(affprob, u0 = u0map, p = pmap, tspan = (integ.t, integ.t))
955+
new_us = u_getter(integ)
956+
new_ps = p_getter(integ)
957+
affprob = remake(
958+
affprob, u0 = new_us, p = new_ps, tspan = (integ.t, integ.t))
959959
affsol = init(affprob, IDSolve())
960960
(check_error(affsol) === ReturnCode.InitialFailure) &&
961961
throw(UnsolvableCallbackError(all_equations(aff)))
962962

963-
for (setu!, getu) in zip(u_setters, affu_getters)
964-
setu!(integ, getu(affsol))
965-
end
966-
for (setp!, getp) in zip(p_setters, affp_getters)
967-
setp!(integ, getp(affsol))
968-
end
963+
u_setter!(integ, affu_getter(affsol))
964+
p_setter!(integ, affp_getter(affsol))
969965
end
970966
end
971967
end
@@ -1078,3 +1074,16 @@ function continuous_events_toplevel(sys::AbstractSystem)
10781074
end
10791075
return get_continuous_events(sys)
10801076
end
1077+
1078+
"""
1079+
Process the symbolic events of a system.
1080+
"""
1081+
function create_symbolic_events(cont_events, disc_events, sys_eqs, iv)
1082+
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !is_diff_equation(eq),
1083+
sys_eqs)
1084+
cont_callbacks = to_cb_vector(cont_events; CB_TYPE = SymbolicContinuousCallback,
1085+
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
1086+
disc_callbacks = to_cb_vector(disc_events; CB_TYPE = SymbolicDiscreteCallback,
1087+
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
1088+
cont_callbacks, disc_callbacks
1089+
end

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -336,13 +336,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
336336
throw(ArgumentError("System names must be unique."))
337337
end
338338

339-
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !is_diff_equation(eq),
340-
deqs)
341-
cont_callbacks = to_cb_vector(continuous_events; CB_TYPE = SymbolicContinuousCallback,
342-
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
343-
disc_callbacks = to_cb_vector(discrete_events; CB_TYPE = SymbolicDiscreteCallback,
344-
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
345-
339+
cont_callbacks, disc_callbacks = create_symbolic_events(
340+
continuous_events, discrete_events, deqs, iv)
346341
if is_dde === nothing
347342
is_dde = _check_if_dde(deqs, iv′, systems)
348343
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,12 +270,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
270270
Wfact = RefValue(EMPTY_JAC)
271271
Wfact_t = RefValue(EMPTY_JAC)
272272

273-
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !is_diff_equation(eq),
274-
deqs)
275-
cont_callbacks = to_cb_vector(continuous_events; CB_TYPE = SymbolicContinuousCallback,
276-
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
277-
disc_callbacks = to_cb_vector(discrete_events; CB_TYPE = SymbolicDiscreteCallback,
278-
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
273+
cont_callbacks, disc_callbacks = create_symbolic_events(
274+
continuous_events, discrete_events, deqs, iv)
279275

280276
if is_dde === nothing
281277
is_dde = _check_if_dde(deqs, iv′, systems)

src/systems/discrete_system/implicit_discrete_system.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,6 @@ function shift_u0map_forward(sys::ImplicitDiscreteSystem, u0map, defs)
298298
for k in collect(keys(u0map))
299299
v = u0map[k]
300300
if !((op = operation(k)) isa Shift)
301-
isnothing(getunshifted(k)) &&
302-
@warn "Initial condition given in term of current state of the unknown. If `build_initializeprob = false`, this may be overridden by the implicit discrete solver."
303-
304301
updated[k] = v
305302
elseif op.steps > 0
306303
error("Initial conditions must be for the current or past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(only(arguments(k)))).")

0 commit comments

Comments
 (0)