Skip to content

Commit 9257db2

Browse files
committed
some tests working
1 parent 1a7a29e commit 9257db2

File tree

8 files changed

+151
-107
lines changed

8 files changed

+151
-107
lines changed

src/linearization.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,6 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
434434
if !iszero(Bs)
435435
if !allow_input_derivatives
436436
der_inds = findall(vec(any(!iszero, Bs, dims = 1)))
437-
@show typeof(der_inds)
438437
error("Input derivatives appeared in expressions (-g_z\\g_u != 0), the following inputs appeared differentiated: $(ModelingToolkit.inputs(sys)[der_inds]). Call `linearize_symbolic` with keyword argument `allow_input_derivatives = true` to allow this and have the returned `B` matrix be of double width ($(2nu)), where the last $nu inputs are the derivatives of the first $nu inputs.")
439438
end
440439
B = [B [zeros(nx, nu); Bs]]

src/systems/callbacks.jl

Lines changed: 86 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -67,21 +67,31 @@ struct AffectSystem
6767
unknowns::Vector
6868
parameters::Vector
6969
discretes::Vector
70-
"""Maps the unknowns in the ImplicitDiscreteSystem to the corresponding parameter or unknown in the parent system."""
71-
affu_to_sysu::Dict
70+
"""Maps the symbols of unknowns/observed in the ImplicitDiscreteSystem to its corresponding unknown/parameter in the parent system."""
71+
aff_to_sys::Dict
7272
end
7373

7474
system(a::AffectSystem) = a.system
7575
discretes(a::AffectSystem) = a.discretes
7676
unknowns(a::AffectSystem) = a.unknowns
7777
parameters(a::AffectSystem) = a.parameters
78-
affu_to_sysu(a::AffectSystem) = a.affu_to_sysu
78+
aff_to_sys(a::AffectSystem) = a.aff_to_sys
79+
previous_vals(a::AffectSystem) = parameters(system(a))
80+
updated_vals(a::AffectSystem) = unknowns(system(a))
7981

8082
function Base.show(iio::IO, aff::AffectSystem)
8183
eqs = vcat(equations(system(aff)), observed(system(aff)))
8284
show(iio, eqs)
8385
end
8486

87+
function Base.:(==)(a1::AffectSystem, a2::AffectSystem)
88+
isequal(system(a1), system(a2)) &&
89+
isequal(discretes(a1), discretes(a2)) &&
90+
isequal(unknowns(a1), unknowns(a2)) &&
91+
isequal(parameters(a1), parameters(a2)) &&
92+
isequal(aff_to_sys(a1), aff_to_sys(a2))
93+
end
94+
8595
"""
8696
Pre(x)
8797
@@ -112,14 +122,14 @@ function (p::Pre)(x)
112122
iscall(x) && operation(x) isa Pre && return x
113123
result = if symbolic_type(x) == ArraySymbolic()
114124
# create an array for `Pre(array)`
115-
Symbolics.array_term(p, toparam(x))
125+
Symbolics.array_term(p, x)
116126
elseif iscall(x) && operation(x) == getindex
117127
# instead of `Pre(x[1])` create `Pre(x)[1]`
118128
# which allows parameter indexing to handle this case automatically.
119129
arr = arguments(x)[1]
120-
term(getindex, p(toparam(arr)), arguments(x)[2:end]...)
130+
term(getindex, p(arr), arguments(x)[2:end]...)
121131
else
122-
term(p, toparam(x))
132+
term(p, x)
123133
end
124134
# the result should be a parameter
125135
result = toparam(result)
@@ -231,25 +241,36 @@ function make_affect(affect::Vector{Equation}; warn = true)
231241
discretes = Any[]
232242
p_as_unknowns = Any[]
233243
for p in params
234-
if iscall(p) && (operator(p) isa Pre)
244+
if iscall(p) && (operation(p) isa Pre)
235245
push!(cb_params, p)
236246
elseif iscall(p) && length(arguments(p)) == 1 &&
237247
isequal(only(arguments(p)), iv)
238248
push!(discretes, p)
239249
push!(p_as_unknowns, tovar(p))
240250
else
241251
push!(discretes, p)
242-
p = iscall(p) ? wrap(Sym{FnType{Tuple{symtype(iv)}, Real}}(nameof(operation(p)))(iv)) :
243-
wrap(Sym{FnType{Tuple{symtype(iv)}, Real}}(nameof(p))(iv))
252+
name = iscall(p) ? nameof(operation(p)) : nameof(p)
253+
p = wrap(Sym{FnType{Tuple{symtype(iv)}, Real}}(name)(iv))
254+
p = setmetadata(p, Symbolics.VariableSource, (:variables, name))
244255
push!(p_as_unknowns, p)
245256
end
246257
end
258+
aff_map = Dict(zip(p_as_unknowns, discretes))
259+
rev_map = Dict([v => k for (k, v) in aff_map])
260+
affect = Symbolics.substitute(affect, rev_map)
247261
@mtkbuild affectsys = ImplicitDiscreteSystem(
248262
affect, iv, collect(union(unknowns, p_as_unknowns)), cb_params)
249-
params = map(x -> only(arguments(unwrap(x))), cb_params)
250-
affmap = Dict(zip([p_as_unknowns, unknowns], [discretes, unknowns]))
263+
params = filter(isparameter, map(x -> only(arguments(unwrap(x))), cb_params))
264+
@show params
265+
266+
for u in unknowns
267+
aff_map[u] = u
268+
end
269+
270+
@show unknowns
271+
@show params
251272

252-
return AffectSystem(affectsys, collect(unknowns), params, discretes, affmap)
273+
return AffectSystem(affectsys, collect(unknowns), params, discretes, aff_map)
253274
end
254275

255276
function make_affect(affect)
@@ -393,17 +414,19 @@ function SymbolicDiscreteCallbacks(events, algeeqs::Vector{Equation} = Equation[
393414

394415
for event in events
395416
cond, affs = event isa Pair ? (event[1], event[2]) : (event, nothing)
396-
if aff isa AbstractVector
397-
aff = vcat(aff, algeeqs)
417+
if affs isa AbstractVector
418+
affs = vcat(affs, algeeqs)
398419
end
399-
affect = make_affect(aff)
400-
push!(callbacks, SymbolicDiscreteCallback(cond, affect, nothing, nothing))
420+
affect = make_affect(affs)
421+
push!(callbacks, SymbolicDiscreteCallback(cond, affect))
401422
end
402423
callbacks
403424
end
404425

405426
function is_timed_condition(condition::T) where {T}
406-
if T <: Real
427+
if T === Num
428+
false
429+
elseif T <: Real
407430
true
408431
elseif T <: AbstractVector
409432
eltype(condition) <: Real
@@ -582,23 +605,31 @@ function compile_condition(cbs::Union{AbstractCallback, Vector{<:AbstractCallbac
582605
condit = substitute(condit, cmap)
583606
end
584607

585-
f_oop, f_iip = build_function_wrapper(sys,
586-
condit, u, t, p...; expression = Val{true},
587-
p_start = 3, p_end = length(p) + 2,
608+
if !is_discrete(cbs)
609+
condit = [cond.lhs - cond.rhs for cond in condit]
610+
end
611+
612+
fs = build_function_wrapper(sys,
613+
condit, u, p..., t; expression,
588614
kwargs...)
589615

590-
if cbs isa AbstractVector
591-
cond(out, u, t, integ) = f_iip(out, u, t, parameter_values(integ))
616+
if expression == Val{true}
617+
fs = eval_or_rgf.(fs; eval_expression, eval_module)
618+
end
619+
is_discrete(cbs) ? (f_oop = fs) : (f_oop, f_iip = fs)
620+
621+
cond = if cbs isa AbstractVector
622+
(out, u, t, integ) -> f_iip(out, u, parameter_values(integ), t)
592623
elseif is_discrete(cbs)
593-
cond(u, t, integ) = f_oop(u, t, parameter_values(integ))
624+
(u, t, integ) -> f_oop(u, parameter_values(integ), t)
594625
else
595-
cond = function (u, t, integ)
626+
function (u, t, integ)
596627
if DiffEqBase.isinplace(integ.sol.prob)
597628
tmp, = DiffEqBase.get_tmp_cache(integ)
598-
f_iip(tmp, u, t, parameter_values(integ))
629+
f_iip(tmp, u, parameter_values(integ), t)
599630
tmp[1]
600631
else
601-
f_oop(u, t, parameter_values(integ))
632+
f_oop(u, parameter_values(integ), t)
602633
end
603634
end
604635
end
@@ -641,6 +672,7 @@ function compile_functional_affect(affect::FunctionalAffect, cb, sys, dvs, ps; k
641672
end
642673

643674
is_discrete(cb::AbstractCallback) = cb isa SymbolicDiscreteCallback
675+
is_discrete(cb::Vector{<:AbstractCallback}) = eltype(cb) isa SymbolicDiscreteCallback
644676

645677
function generate_continuous_callbacks(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys; initial_parameters = true); kwargs...)
646678
cbs = continuous_events(sys)
@@ -668,27 +700,27 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
668700
return generate_callback(cbs[cb_ind], sys; kwargs...)
669701
end
670702

671-
trigger = compile_condition(cbs, sys, dvs, ps; kwargs...)
703+
trigger = compile_condition(cbs, sys, unknowns(sys), parameters(sys; initial_parameters = true); kwargs...)
672704
affects = []
673705
affect_negs = []
674706
inits = []
675707
finals = []
676708
for cb in cbs
677-
affect = compile_affect(cb.affect, cb, sys)
709+
affect = compile_affect(cb.affect, cb, sys, default = (args...) -> ())
678710

679711
push!(affects, affect)
680712
push!(affect_negs, compile_affect(cb.affect_neg, cb, sys, default = affect))
681-
push!(inits, compile_affect(cb.initialize, cb, sys, default = SciMLBase.INITALIZE_DEFAULT))
682-
push!(finals, compile_affect(cb.finalize, cb, sys, default = SciMLBase.FINALIZE_DEFAULT))
713+
push!(inits, compile_affect(cb.initialize, cb, sys, default = nothing))
714+
push!(finals, compile_affect(cb.finalize, cb, sys, default = nothing))
683715
end
684716

685717
# Since there may be different number of conditions and affects,
686718
# we build a map that translates the condition eq. number to the affect number
687-
num_eqs = length.(eqs)
688719
eq2affect = reduce(vcat,
689720
[fill(i, num_eqs[i]) for i in eachindex(affects)])
721+
eqs = reduce(vcat, eqs)
690722
@assert length(eq2affect) == length(eqs)
691-
@assert maximum(eq2affect) == length(affect_functions)
723+
@assert maximum(eq2affect) == length(affects)
692724

693725
affect = function (integ, idx)
694726
affects[eq2affect[idx]](integ)
@@ -702,8 +734,8 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
702734
finalize = compile_vector_optional_affect(finals, SciMLBase.FINALIZE_DEFAULT)
703735

704736
return VectorContinuousCallback(
705-
trigger, affect, length(cbs); affect_neg, initialize, finalize,
706-
rootfind = callback.rootfind, initializealg = SciMLBase.NoInit)
737+
trigger, affect, affect_neg, length(eqs); initialize, finalize,
738+
rootfind = cbs[1].rootfind, initializealg = SciMLBase.NoInit)
707739
end
708740

709741
function generate_callback(cb, sys; kwargs...)
@@ -712,14 +744,14 @@ function generate_callback(cb, sys; kwargs...)
712744
ps = parameters(sys; initial_parameters = true)
713745

714746
trigger = is_timed ? conditions(cb) : compile_condition(cb, sys, dvs, ps; kwargs...)
715-
affect = compile_affect(cb.affect, cb, sys)
747+
affect = compile_affect(cb.affect, cb, sys, default = (args...) -> ())
716748
affect_neg = hasfield(typeof(cb), :affect_neg) ?
717749
compile_affect(cb.affect_neg, cb, sys, default = affect) : nothing
718750
initialize = compile_affect(cb.initialize, cb, sys, default = SciMLBase.INITIALIZE_DEFAULT)
719751
finalize = compile_affect(cb.finalize, cb, sys, default = SciMLBase.FINALIZE_DEFAULT)
720752

721753
if is_discrete(cb)
722-
if is_timed && condition(cb) isa AbstractVector
754+
if is_timed && conditions(cb) isa AbstractVector
723755
return PresetTimeCallback(trigger, affect; affect_neg, initialize,
724756
finalize, initializealg = SciMLBase.NoInit)
725757
elseif is_timed
@@ -762,22 +794,30 @@ function compile_affect(
762794

763795
ps = parameters(aff)
764796
dvs = unknowns(aff)
797+
@show ps
765798

766799
if aff isa AffectSystem
767-
aff_map = affu_to_sysu(aff)
800+
aff_map = aff_to_sys(aff)
801+
sys_map = Dict([v => k for (k, v) in aff_map])
802+
build_initializeprob = has_alg_eqs(sys)
803+
768804
function affect!(integrator)
769-
pmap = []
770-
for pre_p in parameters(system(affect))
805+
pmap = Pair[]
806+
for pre_p in previous_vals(aff)
771807
p = only(arguments(unwrap(pre_p)))
772-
push!(pmap, pre_p => integrator[p])
773-
end
774-
guesses = [u => integrator[aff_map[u]] for u in unknowns(system(affect))]
775-
prob = ImplicitDiscreteProblem(system(affect), [], (0, 1), pmap; guesses)
776-
sol = init(prob, SimpleIDSolve())
777-
for u in unknowns(system(affect))
778-
integrator[aff_map[u]] = sol[u]
808+
pval = isparameter(p) ? integrator.ps[p] : integrator[p]
809+
push!(pmap, pre_p => pval)
779810
end
811+
guesses = Pair[u => integrator[aff_map[u]] for u in updated_vals(aff)]
812+
affprob = ImplicitDiscreteProblem(system(aff), Pair[], (0, 1), pmap; guesses, build_initializeprob)
780813

814+
affsol = init(affprob, SimpleIDSolve())
815+
for u in unknowns(aff)
816+
integrator[u] = affsol[u]
817+
end
818+
for p in discretes(aff)
819+
integrator.ps[p] = affsol[sys_map[p]]
820+
end
781821
for idx in save_idxs
782822
SciMLBase.save_discretes!(integ, idx)
783823
end

src/systems/diffeqs/odesystem.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,6 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
311311
cons = get_constraintsystem(sys)
312312
cons !== nothing && push!(conssystems, cons)
313313
end
314-
@show conssystems
315314
@set! constraintsystem.systems = conssystems
316315
end
317316

src/systems/discrete_system/discrete_system.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,3 +422,13 @@ end
422422
function DiscreteFunctionExpr(sys::DiscreteSystem, args...; kwargs...)
423423
DiscreteFunctionExpr{true}(sys, args...; kwargs...)
424424
end
425+
426+
function Base.:(==)(sys1::DiscreteSystem, sys2::DiscreteSystem)
427+
sys1 === sys2 && return true
428+
isequal(nameof(sys1), nameof(sys2)) &&
429+
isequal(get_iv(sys1), get_iv(sys2)) &&
430+
_eq_unordered(get_eqs(sys1), get_eqs(sys2)) &&
431+
_eq_unordered(get_unknowns(sys1), get_unknowns(sys2)) &&
432+
_eq_unordered(get_ps(sys1), get_ps(sys2)) &&
433+
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
434+
end

src/systems/discrete_system/implicit_discrete_system.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -268,17 +268,17 @@ function generate_function(
268268
iv = get_iv(sys)
269269
# Algebraic equations get shifted forward 1, to match with differential equations
270270
exprs = map(equations(sys)) do eq
271-
_iszero(eq.lhs) ? distribute_shift(Next(eq.rhs)) : (eq.rhs - eq.lhs)
271+
_iszero(eq.lhs) ? distribute_shift(Shift(iv, 1)(eq.rhs)) : (eq.rhs - eq.lhs)
272272
end
273273

274274
# Handle observables in algebraic equations, since they are shifted
275275
obs = observed(sys)
276-
shifted_obs = Symbolics.Equation[distribute_shift(Next(eq)) for eq in obs]
276+
shifted_obs = Symbolics.Equation[distribute_shift(Shift(iv, 1)(eq)) for eq in obs]
277277
obsidxs = observed_equations_used_by(sys, exprs; obs = shifted_obs)
278278
extra_assignments = [Assignment(shifted_obs[i].lhs, shifted_obs[i].rhs)
279279
for i in obsidxs]
280280

281-
u_next = map(Next, dvs)
281+
u_next = map(Shift(iv, 1), dvs)
282282
u = dvs
283283
build_function_wrapper(
284284
sys, exprs, u_next, u, ps..., iv; p_start = 3, extra_assignments, kwargs...)
@@ -334,10 +334,8 @@ function SciMLBase.ImplicitDiscreteProblem(
334334

335335
u0map = to_varmap(u0map, dvs)
336336
u0map = shift_u0map_forward(sys, u0map, defaults(sys))
337-
@show u0map
338337
f, u0, p = process_SciMLProblem(
339338
ImplicitDiscreteFunction, sys, u0map, parammap; eval_expression, eval_module, kwargs...)
340-
@show u0
341339

342340
kwargs = filter_kwargs(kwargs)
343341
ImplicitDiscreteProblem(f, u0, tspan, p; kwargs...)
@@ -438,3 +436,13 @@ end
438436
function ImplicitDiscreteFunctionExpr(sys::ImplicitDiscreteSystem, args...; kwargs...)
439437
ImplicitDiscreteFunctionExpr{true}(sys, args...; kwargs...)
440438
end
439+
440+
function Base.:(==)(sys1::ImplicitDiscreteSystem, sys2::ImplicitDiscreteSystem)
441+
sys1 === sys2 && return true
442+
isequal(nameof(sys1), nameof(sys2)) &&
443+
isequal(get_iv(sys1), get_iv(sys2)) &&
444+
_eq_unordered(get_eqs(sys1), get_eqs(sys2)) &&
445+
_eq_unordered(get_unknowns(sys1), get_unknowns(sys2)) &&
446+
_eq_unordered(get_ps(sys1), get_ps(sys2)) &&
447+
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
448+
end

src/systems/index_cache.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,7 @@ function IndexCache(sys::AbstractSystem)
115115
affs = [affs]
116116
end
117117
for affect in affs
118-
if affect isa Equation
119-
is_parameter(sys, affect.lhs) && push!(discs, affect.lhs)
120-
elseif affect isa FunctionalAffect || affect isa ImperativeAffect
118+
if affect isa AffectSystem || affect isa FunctionalAffect || affect isa ImperativeAffect
121119
union!(discs, unwrap.(discretes(affect)))
122120
elseif isnothing(affect)
123121
continue

src/systems/problem_utils.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ Keyword arguments:
346346
function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
347347
tofloat = true, use_union = true, container_type = Array,
348348
toterm = default_toterm, promotetoconcrete = nothing, check = true, allow_symbolic = false)
349-
isempty(vars) && return nothing
349+
isempty(vars) && return Float64[]
350350

351351
if check
352352
missing_vars = missingvars(varmap, vars; toterm)
@@ -369,9 +369,7 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
369369
vals = promote_to_concrete(vals; tofloat = tofloat, use_union = use_union)
370370
end
371371

372-
if isempty(vals)
373-
return nothing
374-
elseif container_type <: Tuple
372+
if container_type <: Tuple
375373
return (vals...,)
376374
else
377375
return SymbolicUtils.Code.create_array(container_type, eltype(vals), Val{1}(),

0 commit comments

Comments
 (0)