Skip to content

Commit b9fc098

Browse files
committed
up
1 parent f852656 commit b9fc098

File tree

4 files changed

+410
-470
lines changed

4 files changed

+410
-470
lines changed

src/systems/callbacks.jl

Lines changed: 32 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ struct AffectSystem
6262
discretes::Vector
6363
"""Maps the symbols of unknowns/observed in the ImplicitDiscreteSystem to its corresponding unknown/parameter in the parent system."""
6464
aff_to_sys::Dict
65-
explicit::Bool
6665
end
6766

6867
system(a::AffectSystem) = a.system
@@ -71,11 +70,11 @@ unknowns(a::AffectSystem) = a.unknowns
7170
parameters(a::AffectSystem) = a.parameters
7271
aff_to_sys(a::AffectSystem) = a.aff_to_sys
7372
previous_vals(a::AffectSystem) = parameters(system(a))
74-
is_explicit(a::AffectSystem) = a.explicit
73+
all_equations(a::AffectSystem) = vcat(equations(system(a)), observed(system(a)))
7574

7675
function Base.show(iio::IO, aff::AffectSystem)
7776
println(iio, "Affect system defined by equations:")
78-
eqs = vcat(equations(system(aff)), observed(system(aff)))
77+
eqs = all_equations(aff)
7978
show(iio, eqs)
8079
end
8180

@@ -84,17 +83,15 @@ function Base.:(==)(a1::AffectSystem, a2::AffectSystem)
8483
isequal(discretes(a1), discretes(a2)) &&
8584
isequal(unknowns(a1), unknowns(a2)) &&
8685
isequal(parameters(a1), parameters(a2)) &&
87-
isequal(aff_to_sys(a1), aff_to_sys(a2)) &&
88-
isequal(is_explicit(a1), is_explicit(a2))
86+
isequal(aff_to_sys(a1), aff_to_sys(a2))
8987
end
9088

9189
function Base.hash(a::AffectSystem, s::UInt)
9290
s = hash(system(a), s)
9391
s = hash(unknowns(a), s)
9492
s = hash(parameters(a), s)
9593
s = hash(discretes(a), s)
96-
s = hash(aff_to_sys(a), s)
97-
hash(is_explicit(a), s)
94+
hash(aff_to_sys(a), s)
9895
end
9996

10097
function vars!(vars, aff::Union{FunctionalAffect, AffectSystem}; op = Differential)
@@ -241,51 +238,44 @@ make_affect(affect::Tuple; kwargs...) = FunctionalAffect(affect...)
241238
make_affect(affect::NamedTuple; kwargs...) = FunctionalAffect(; affect...)
242239
make_affect(affect::Affect; kwargs...) = affect
243240

244-
function make_affect(affect::Vector{Equation}; discrete_parameters = Any[], iv = nothing, algeeqs::Vector{Equation} = Equation[])
241+
function make_affect(affect::Vector{Equation}; discrete_parameters::AbstractVector = Any[], iv = nothing, algeeqs::Vector{Equation} = Equation[])
245242
isempty(affect) && return nothing
246243
isempty(algeeqs) && @warn "No algebraic equations were found for the callback defined by $(join(affect, ", ")). If the system has no algebraic equations, this can be disregarded. Otherwise pass in `algeeqs` to the SymbolicContinuousCallback constructor."
244+
isnothing(iv) && error("Must specify iv.")
247245

248-
for p in discretes
249-
# Check if p is time-dependent
250-
false && error("Non-time dependent parameter $p passed in as a discrete. Must be declared as $p(t).")
246+
for p in discrete_parameters
247+
occursin(unwrap(iv), unwrap(p)) || error("Non-time dependent parameter $p passed in as a discrete. Must be declared as @parameters $p(t).")
251248
end
252249

253-
explicit = true
254250
dvs = OrderedSet()
255251
params = OrderedSet()
256252
for eq in affect
257-
if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic())
253+
if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic() || symbolic_type(eq.lhs) === NotSymbolic())
258254
@warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x)."
259-
explicit = false
260255
end
261256
collect_vars!(dvs, params, eq, iv; op = Pre)
262257
end
263258
for eq in algeeqs
264259
collect_vars!(dvs, params, eq, iv)
265-
explicit = false
266-
end
267-
any(isirreducible, dvs) && (explicit = false)
268-
269-
if isnothing(iv)
270-
iv = isempty(dvs) ? iv : only(arguments(dvs[1]))
271-
isnothing(iv) && @warn "No independent variable specified and could not be inferred. If the iv appears in an affect equation explicitly, like x ~ t + 1, then it must be specified as an argument to the SymbolicContinuousCallback or SymbolicDiscreteCallback constructor. Otherwise this warning can be disregarded."
272260
end
273261

274262
pre_params = filter(haspre value, params)
275-
sys_params = setdiff(params, union(discrete_parameters, pre_params))
263+
sys_params = collect(setdiff(params, union(discrete_parameters, pre_params)))
276264
discretes = map(tovar, discrete_parameters)
277265
aff_map = Dict(zip(discretes, discrete_parameters))
278-
@named affectsys = ImplicitDiscreteSystem(vcat(affect, algeeqs), iv, collect(union(dvs, discretes)), collect(union(pre_params, sys_params)))
279-
affectsys = complete(affectsys)
266+
rev_map = Dict(zip(discrete_parameters, discretes))
267+
affect = Symbolics.fast_substitute(affect, rev_map)
268+
algeeqs = Symbolics.fast_substitute(algeeqs, rev_map)
269+
@mtkbuild affectsys = ImplicitDiscreteSystem(vcat(affect, algeeqs), iv, collect(union(dvs, discretes)), collect(union(pre_params, sys_params)))
280270
# get accessed parameters p from Pre(p) in the callback parameters
281-
accessed_params = filter(isparameter, map(x -> unPre(x), cb_params))
271+
accessed_params = filter(isparameter, map(unPre, collect(pre_params)))
282272
union!(accessed_params, sys_params)
283273
# add unknowns to the map
284274
for u in dvs
285275
aff_map[u] = u
286276
end
287277

288-
AffectSystem(affectsys, collect(dvs), collect(accessed_params), collect(discrete_parameters), aff_map, explicit)
278+
AffectSystem(affectsys, collect(dvs), collect(accessed_params), collect(discrete_parameters), aff_map)
289279
end
290280

291281
function make_affect(affect; kwargs...)
@@ -295,7 +285,7 @@ end
295285
"""
296286
Generate continuous callbacks.
297287
"""
298-
function SymbolicContinuousCallbacks(events; algeeqs::Vector{Equation} = Equation[], iv = nothing)
288+
function SymbolicContinuousCallbacks(events; discrete_parameters = Any[], algeeqs::Vector{Equation} = Equation[], iv = nothing)
299289
callbacks = SymbolicContinuousCallback[]
300290
isnothing(events) && return callbacks
301291

@@ -304,7 +294,7 @@ function SymbolicContinuousCallbacks(events; algeeqs::Vector{Equation} = Equatio
304294

305295
for event in events
306296
cond, affs = event isa Pair ? (event[1], event[2]) : (event, nothing)
307-
push!(callbacks, SymbolicContinuousCallback(cond, affs; iv, algeeqs))
297+
push!(callbacks, SymbolicContinuousCallback(cond, affs; iv, algeeqs, discrete_parameters))
308298
end
309299
callbacks
310300
end
@@ -412,11 +402,11 @@ struct SymbolicDiscreteCallback <: AbstractCallback
412402

413403
function SymbolicDiscreteCallback(
414404
condition, affect = nothing;
415-
initialize = nothing, finalize = nothing, iv = nothing, algeeqs = Equation[])
405+
initialize = nothing, finalize = nothing, iv = nothing, algeeqs = Equation[], discrete_parameters = Any[])
416406
c = is_timed_condition(condition) ? condition : value(scalarize(condition))
417407

418-
new(c, make_affect(affect; iv, algeeqs), make_affect(initialize; iv, algeeqs),
419-
make_affect(finalize; iv, algeeqs))
408+
new(c, make_affect(affect; iv, algeeqs, discrete_parameters), make_affect(initialize; iv, algeeqs, discrete_parameters),
409+
make_affect(finalize; iv, algeeqs, discrete_parameters))
420410
end # Default affect to nothing
421411
end
422412

@@ -426,7 +416,7 @@ SymbolicDiscreteCallback(cb::SymbolicDiscreteCallback, args...; kwargs...) = cb
426416
"""
427417
Generate discrete callbacks.
428418
"""
429-
function SymbolicDiscreteCallbacks(events; algeeqs::Vector{Equation} = Equation[], iv = nothing)
419+
function SymbolicDiscreteCallbacks(events; discrete_parameters::Vector = Any[], algeeqs::Vector{Equation} = Equation[], iv = nothing)
430420
callbacks = SymbolicDiscreteCallback[]
431421

432422
isnothing(events) && return callbacks
@@ -435,7 +425,7 @@ function SymbolicDiscreteCallbacks(events; algeeqs::Vector{Equation} = Equation[
435425

436426
for event in events
437427
cond, affs = event isa Pair ? (event[1], event[2]) : (event, nothing)
438-
push!(callbacks, SymbolicDiscreteCallback(cond, affs; iv, algeeqs))
428+
push!(callbacks, SymbolicDiscreteCallback(cond, affs; iv, algeeqs, discrete_parameters))
439429
end
440430
callbacks
441431
end
@@ -471,7 +461,7 @@ function namespace_affects(affect::AffectSystem, s)
471461
renamespace.((s,), unknowns(affect)),
472462
renamespace.((s,), parameters(affect)),
473463
renamespace.((s,), discretes(affect)),
474-
Dict([k => renamespace(s, v) for (k, v) in aff_to_sys(affect)]), is_explicit(affect))
464+
Dict([k => renamespace(s, v) for (k, v) in aff_to_sys(affect)]))
475465
end
476466
namespace_affects(af::Nothing, s) = nothing
477467

@@ -837,19 +827,17 @@ function compile_equational_affect(aff::Union{AffectSystem, Vector{Equation}}, s
837827
aff_map = aff_to_sys(aff)
838828
sys_map = Dict([v => k for (k, v) in aff_map])
839829

840-
if is_explicit(aff)
841-
affsys = structural_simplify(affsys)
842-
@assert isempty(equations(affsys))
830+
if isempty(equations(affsys))
843831
update_eqs = Symbolics.fast_substitute(observed(affsys), Dict([p => unPre(p) for p in parameters(affsys)]))
844832
rhss = map(x -> x.rhs, update_eqs)
845833
lhss = map(x -> aff_map[x.lhs], update_eqs)
846834
is_p = [lhs Set(ps_to_update) for lhs in lhss]
847-
835+
is_u = [lhs Set(dvs_to_update) for lhs in lhss]
848836
dvs = unknowns(sys)
849837
ps = parameters(sys)
850838
t = get_iv(sys)
851839

852-
u_idxs = indexin((@view lhss[.!is_p]), dvs)
840+
u_idxs = indexin((@view lhss[is_u]), dvs)
853841

854842
wrap_mtkparameters = has_index_cache(sys) && (get_index_cache(sys) !== nothing)
855843
p_idxs = if wrap_mtkparameters
@@ -861,7 +849,7 @@ function compile_equational_affect(aff::Union{AffectSystem, Vector{Equation}}, s
861849
_ps = reorder_parameters(sys, ps)
862850
integ = gensym(:MTKIntegrator)
863851

864-
u_up, u_up! = build_function_wrapper(sys, (@view rhss[.!is_p]), dvs, _ps..., t; wrap_code = add_integrator_header(sys, integ, :u), expression = Val{false}, outputidxs = u_idxs, wrap_mtkparameters)
852+
u_up, u_up! = build_function_wrapper(sys, (@view rhss[is_u]), dvs, _ps..., t; wrap_code = add_integrator_header(sys, integ, :u), expression = Val{false}, outputidxs = u_idxs, wrap_mtkparameters)
865853
p_up, p_up! = build_function_wrapper(sys, (@view rhss[is_p]), dvs, _ps..., t; wrap_code = add_integrator_header(sys, integ, :p), expression = Val{false}, outputidxs = p_idxs, wrap_mtkparameters)
866854

867855
return function explicit_affect!(integ)
@@ -870,7 +858,7 @@ function compile_equational_affect(aff::Union{AffectSystem, Vector{Equation}}, s
870858
reset_jumps && reset_aggregated_jumps!(integ)
871859
end
872860
else
873-
return let dvs_to_update = dvs_to_update, aff_map = aff_map, sys_map = sys_map, affsys = affsys, ps_to_update = ps_to_update
861+
return let dvs_to_update = dvs_to_update, aff_map = aff_map, sys_map = sys_map, affsys = affsys, ps_to_update = ps_to_update, aff = aff
874862
function implicit_affect!(integ)
875863
pmap = Pair[]
876864
for pre_p in parameters(affsys)
@@ -885,7 +873,7 @@ function compile_equational_affect(aff::Union{AffectSystem, Vector{Equation}}, s
885873
end
886874
affprob = ImplicitDiscreteProblem(affsys, u0, (integ.t, integ.t), pmap; build_initializeprob = false, check_length = false)
887875
affsol = init(affprob, IDSolve())
888-
check_error(affsol) && throw(UnsolvableCallbackError(equations(affsys)))
876+
(check_error(affsol) === ReturnCode.InitialFailure) && throw(UnsolvableCallbackError(all_equations(aff)))
889877
for u in dvs_to_update
890878
integ[u] = affsol[sys_map[u]]
891879
end
@@ -901,8 +889,8 @@ struct UnsolvableCallbackError
901889
eqs::Vector{Equation}
902890
end
903891

904-
function Base.showerror(io, err::UnsolvableCallbackError)
905-
println(io, "The callback defined by the equations, $(join(err.eqs, "\n")), with discrete parameters is not solvable. Please check the algebraic equations, affect equations, and declared discrete parameters.")
892+
function Base.showerror(io::IO, err::UnsolvableCallbackError)
893+
println(io, "The callback defined by the following equations:\n\n$(join(err.eqs, "\n"))\n\nis not solvable. Please check that the algebraic equations and affect equations are correct, and that all parameters intended to be changed are passed in as `discrete_parameters`.")
906894
end
907895

908896
merge_cb(::Nothing, ::Nothing) = nothing

src/systems/discrete_system/implicit_discrete_system.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,6 @@ function generate_function(
281281
u_next = map(Shift(iv, 1), dvs)
282282
u = dvs
283283
p = (reorder_parameters(sys, unwrap.(ps))..., cachesyms...)
284-
@show exprs
285284
build_function_wrapper(
286285
sys, exprs, u_next, u, p..., iv; p_start = 3, extra_assignments, kwargs...)
287286
end

src/systems/index_cache.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ function IndexCache(sys::AbstractSystem)
127127
end
128128

129129
for sym in discs
130+
@show sym
130131
is_parameter(sys, sym) ||
131132
error("Expected discrete variable $sym in callback to be a parameter")
132133

0 commit comments

Comments
 (0)