Skip to content

Commit f852656

Browse files
committed
feat: specify discrete_parameters
1 parent bb4ef14 commit f852656

File tree

4 files changed

+1303
-1243
lines changed

4 files changed

+1303
-1243
lines changed

src/systems/callbacks.jl

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
219219
function SymbolicContinuousCallback(
220220
conditions::Union{Equation, Vector{Equation}},
221221
affect = nothing;
222+
discrete_parameters = Any[],
222223
affect_neg = affect,
223224
initialize = nothing,
224225
finalize = nothing,
@@ -227,8 +228,8 @@ struct SymbolicContinuousCallback <: AbstractCallback
227228
algeeqs = Equation[])
228229

229230
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
230-
new(conditions, make_affect(affect; iv, algeeqs), make_affect(affect_neg; iv, algeeqs),
231-
make_affect(initialize; iv, algeeqs), make_affect(finalize; iv, algeeqs), rootfind)
231+
new(conditions, make_affect(affect; iv, algeeqs, discrete_parameters), make_affect(affect_neg; iv, algeeqs, discrete_parameters),
232+
make_affect(initialize; iv, algeeqs, discrete_parameters), make_affect(finalize; iv, algeeqs, discrete_parameters), rootfind)
232233
end # Default affect to nothing
233234
end
234235

@@ -240,10 +241,15 @@ make_affect(affect::Tuple; kwargs...) = FunctionalAffect(affect...)
240241
make_affect(affect::NamedTuple; kwargs...) = FunctionalAffect(; affect...)
241242
make_affect(affect::Affect; kwargs...) = affect
242243

243-
function make_affect(affect::Vector{Equation}; iv = nothing, algeeqs::Vector{Equation} = Equation[])
244+
function make_affect(affect::Vector{Equation}; discrete_parameters = Any[], iv = nothing, algeeqs::Vector{Equation} = Equation[])
244245
isempty(affect) && return nothing
245246
isempty(algeeqs) && @warn "No algebraic equations were found for the callback defined by $(join(affect, ", ")). If the system has no algebraic equations, this can be disregarded. Otherwise pass in `algeeqs` to the SymbolicContinuousCallback constructor."
246247

248+
for p in discretes
249+
# Check if p is time-dependent
250+
false && error("Non-time dependent parameter $p passed in as a discrete. Must be declared as $p(t).")
251+
end
252+
247253
explicit = true
248254
dvs = OrderedSet()
249255
params = OrderedSet()
@@ -265,38 +271,21 @@ function make_affect(affect::Vector{Equation}; iv = nothing, algeeqs::Vector{Equ
265271
isnothing(iv) && @warn "No independent variable specified and could not be inferred. If the iv appears in an affect equation explicitly, like x ~ t + 1, then it must be specified as an argument to the SymbolicContinuousCallback or SymbolicDiscreteCallback constructor. Otherwise this warning can be disregarded."
266272
end
267273

268-
# Parameters in affect equations should become unknowns in the ImplicitDiscreteSystem.
269-
cb_params = Any[]
270-
discretes = Any[]
271-
p_as_dvs = Any[]
272-
for p in params
273-
if iscall(p) && (operation(p) isa Pre)
274-
push!(cb_params, p)
275-
elseif iscall(p) && length(arguments(p)) == 1 &&
276-
isequal(only(arguments(p)), iv)
277-
push!(discretes, p)
278-
push!(p_as_dvs, tovar(p))
279-
else
280-
push!(discretes, p)
281-
name = iscall(p) ? nameof(operation(p)) : nameof(p)
282-
p = wrap(Sym{FnType{Tuple{symtype(iv)}, Real}}(name)(iv))
283-
p = setmetadata(p, Symbolics.VariableSource, (:variables, name))
284-
push!(p_as_dvs, p)
285-
end
286-
end
287-
aff_map = Dict(zip(p_as_dvs, discretes))
288-
rev_map = Dict([v => k for (k, v) in aff_map])
289-
affect = Symbolics.substitute(affect, rev_map)
290-
@named affectsys = ImplicitDiscreteSystem(vcat(affect, algeeqs), iv, collect(union(dvs, p_as_dvs)), cb_params)
274+
pre_params = filter(haspre value, params)
275+
sys_params = setdiff(params, union(discrete_parameters, pre_params))
276+
discretes = map(tovar, discrete_parameters)
277+
aff_map = Dict(zip(discretes, discrete_parameters))
278+
@named affectsys = ImplicitDiscreteSystem(vcat(affect, algeeqs), iv, collect(union(dvs, discretes)), collect(union(pre_params, sys_params)))
291279
affectsys = complete(affectsys)
292280
# get accessed parameters p from Pre(p) in the callback parameters
293-
params = filter(isparameter, map(x -> unPre(x), cb_params))
281+
accessed_params = filter(isparameter, map(x -> unPre(x), cb_params))
282+
union!(accessed_params, sys_params)
294283
# add unknowns to the map
295284
for u in dvs
296285
aff_map[u] = u
297286
end
298287

299-
AffectSystem(affectsys, collect(dvs), params, discretes, aff_map, explicit)
288+
AffectSystem(affectsys, collect(dvs), collect(accessed_params), collect(discrete_parameters), aff_map, explicit)
300289
end
301290

302291
function make_affect(affect; kwargs...)
@@ -876,8 +865,8 @@ function compile_equational_affect(aff::Union{AffectSystem, Vector{Equation}}, s
876865
p_up, p_up! = build_function_wrapper(sys, (@view rhss[is_p]), dvs, _ps..., t; wrap_code = add_integrator_header(sys, integ, :p), expression = Val{false}, outputidxs = p_idxs, wrap_mtkparameters)
877866

878867
return function explicit_affect!(integ)
879-
u_up!(integ)
880-
p_up!(integ)
868+
isempty(dvs_to_update) || u_up!(integ)
869+
isempty(ps_to_update) || p_up!(integ)
881870
reset_jumps && reset_aggregated_jumps!(integ)
882871
end
883872
else
@@ -891,11 +880,12 @@ function compile_equational_affect(aff::Union{AffectSystem, Vector{Equation}}, s
891880
end
892881
u0 = Pair[]
893882
for u in unknowns(affsys)
894-
uval = isparameter(aff_map[u]) ? integ.ps[u] : integ[u]
883+
uval = isparameter(aff_map[u]) ? integ.ps[aff_map[u]] : integ[u]
895884
push!(u0, u => uval)
896885
end
897886
affprob = ImplicitDiscreteProblem(affsys, u0, (integ.t, integ.t), pmap; build_initializeprob = false, check_length = false)
898-
affsol = init(affprob, SimpleIDSolve())
887+
affsol = init(affprob, IDSolve())
888+
check_error(affsol) && throw(UnsolvableCallbackError(equations(affsys)))
899889
for u in dvs_to_update
900890
integ[u] = affsol[sys_map[u]]
901891
end
@@ -907,6 +897,14 @@ function compile_equational_affect(aff::Union{AffectSystem, Vector{Equation}}, s
907897
end
908898
end
909899

900+
struct UnsolvableCallbackError
901+
eqs::Vector{Equation}
902+
end
903+
904+
function Base.showerror(io, err::UnsolvableCallbackError)
905+
println(io, "The callback defined by the equations, $(join(err.eqs, "\n")), with discrete parameters is not solvable. Please check the algebraic equations, affect equations, and declared discrete parameters.")
906+
end
907+
910908
merge_cb(::Nothing, ::Nothing) = nothing
911909
merge_cb(::Nothing, x) = merge_cb(x, nothing)
912910
merge_cb(x, ::Nothing) = x

src/systems/discrete_system/implicit_discrete_system.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ function generate_function(
281281
u_next = map(Shift(iv, 1), dvs)
282282
u = dvs
283283
p = (reorder_parameters(sys, unwrap.(ps))..., cachesyms...)
284+
@show exprs
284285
build_function_wrapper(
285286
sys, exprs, u_next, u, p..., iv; p_start = 3, extra_assignments, kwargs...)
286287
end
@@ -375,6 +376,12 @@ function SciMLBase.ImplicitDiscreteFunction{iip, specialize}(
375376
f(u_next, u, p, t) = f_oop(u_next, u, p, t)
376377
f(resid, u_next, u, p, t) = f_iip(resid, u_next, u, p, t)
377378

379+
if length(dvs) == length(equations(sys))
380+
resid_prototype = nothing
381+
else
382+
resid_prototype = calculate_resid_prototype(length(equations(sys)), u0, p)
383+
end
384+
378385
if specialize === SciMLBase.FunctionWrapperSpecialize && iip
379386
if u0 === nothing || p === nothing || t === nothing
380387
error("u0, p, and t must be specified for FunctionWrapperSpecialize on ImplicitDiscreteFunction.")
@@ -389,6 +396,7 @@ function SciMLBase.ImplicitDiscreteFunction{iip, specialize}(
389396
sys = sys,
390397
observed = observedfun,
391398
analytic = analytic,
399+
resid_prototype = resid_prototype,
392400
kwargs...)
393401
end
394402

test/jumpsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using ModelingToolkit, DiffEqBase, JumpProcesses, Test, LinearAlgebra
22
using Random, StableRNGs, NonlinearSolve
33
using OrdinaryDiffEq
44
using ModelingToolkit: t_nounits as t, D_nounits as D
5+
using BenchmarkTools
56
MT = ModelingToolkit
67

78
rng = StableRNG(12345)
@@ -79,7 +80,7 @@ function getmean(jprob, Nsims; use_stepper = true)
7980
end
8081
m / Nsims
8182
end
82-
m = getmean(jprob, Nsims)
83+
@btime m = $getmean($jprob, $Nsims)
8384

8485
# test auto-alg selection works
8586
jprobb = JumpProblem(js2, dprob; save_positions = (false, false), rng)

0 commit comments

Comments
 (0)