Skip to content

Commit 1c0c328

Browse files
YingboMabaggepinnen
andcommitted
Better API for discrete controller simulation
Co-authored-by: Fredrik Bagge Carlson <[email protected]>
1 parent 455ae0a commit 1c0c328

File tree

6 files changed

+70
-48
lines changed

6 files changed

+70
-48
lines changed

src/systems/abstractsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,8 @@ for prop in [:eqs
208208
:torn_matching
209209
:tearing_state
210210
:substitutions
211-
:metadata]
211+
:metadata
212+
:discrete_subsystems]
212213
fname1 = Symbol(:get_, prop)
213214
fname2 = Symbol(:has_, prop)
214215
@eval begin

src/systems/clock_inference.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,19 +147,22 @@ function split_system(ci::ClockInference)
147147
@set! ts_i.structure.eq_to_diff = eq_to_diff
148148
tss[id] = ts_i
149149
end
150-
return tss, inputs, continuous_id
150+
return tss, inputs, continuous_id, id_to_clock
151151
end
152152

153-
function generate_discrete_affect(syss, inputs, continuous_id; checkbounds = true,
153+
function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
154+
checkbounds = true,
154155
eval_module = @__MODULE__, eval_expression = true)
155156
out = Sym{Any}(:out)
156157
appended_parameters = parameters(syss[continuous_id])
157158
param_to_idx = Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
158159
offset = length(appended_parameters)
159160
affect_funs = []
160161
svs = []
162+
clocks = TimeDomain[]
161163
for (i, (sys, input)) in enumerate(zip(syss, inputs))
162164
i == continuous_id && continue
165+
push!(clocks, id_to_clock[i])
163166
subs = get_substitutions(sys)
164167
assignments = map(s -> Assignment(s.lhs, s.rhs), subs.subs)
165168
let_body = SetArray(!checkbounds, out, rhss(equations(sys)))
@@ -191,9 +194,9 @@ function generate_discrete_affect(syss, inputs, continuous_id; checkbounds = tru
191194
cont_to_disc_idxs = (offset + 1):(offset += ni)
192195
input_offset = offset
193196
disc_range = (offset + 1):(offset += ns)
194-
save_tuple = Expr(:tuple)
197+
save_vec = Expr(:ref, :Float64)
195198
for i in 1:ns
196-
push!(save_tuple.args, :(p[$(input_offset + i)]))
199+
push!(save_vec.args, :(p[$(input_offset + i)]))
197200
end
198201
affect! = :(function (integrator, saved_values)
199202
@unpack u, p, t = integrator
@@ -208,10 +211,10 @@ function generate_discrete_affect(syss, inputs, continuous_id; checkbounds = tru
208211
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
209212
copyto!(d2c_view, d2c_obs(disc_state, p, t))
210213
push!(saved_values.t, t)
211-
push!(saved_values.saveval, $save_tuple)
214+
push!(saved_values.saveval, $save_vec)
212215
disc(disc_state, disc_state, p, t)
213216
end)
214-
sv = SavedValues(Float64, NTuple{ns, Float64})
217+
sv = SavedValues(Float64, Vector{Float64})
215218
push!(affect_funs, affect!)
216219
push!(svs, sv)
217220
end
@@ -223,5 +226,5 @@ function generate_discrete_affect(syss, inputs, continuous_id; checkbounds = tru
223226
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
224227
end
225228
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
226-
return affects, svs, appended_parameters, defaults
229+
return affects, clocks, svs, appended_parameters, defaults
227230
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,12 @@ function DiffEqBase.ODEProblem{false}(sys::AbstractODESystem, args...; kwargs...
686686
ODEProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
687687
end
688688

689+
struct DiscreteSaveAffect{F, S} <: Function
690+
f::F
691+
s::S
692+
end
693+
(d::DiscreteSaveAffect)(args...) = d.f(args..., d.s)
694+
689695
function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
690696
tspan = get_tspan(sys),
691697
parammap = DiffEqBase.NullParameters();
@@ -698,13 +704,35 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
698704
has_difference = has_difference,
699705
check_length, kwargs...)
700706
cbs = process_events(sys; callback, has_difference, kwargs...)
707+
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
708+
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
709+
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
710+
if clock isa Clock
711+
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
712+
else
713+
error("$clock is not a supported clock type.")
714+
end
715+
end
716+
if cbs === nothing
717+
if length(discrete_cbs) == 1
718+
cbs = only(discrete_cbs)
719+
else
720+
cbs = CallbackSet(discrete_cbs...)
721+
end
722+
else
723+
cbs = CallbackSet(cbs, discrete_cbs)
724+
end
725+
else
726+
svs = nothing
727+
end
701728
kwargs = filter_kwargs(kwargs)
702729
pt = something(get_metadata(sys), StandardODEProblem())
703730

704731
if cbs === nothing
705-
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs...)
732+
ODEProblem{iip}(f, u0, tspan, p, pt; disc_saved_values = svs, kwargs...)
706733
else
707-
ODEProblem{iip}(f, u0, tspan, p, pt; callback = cbs, kwargs...)
734+
ODEProblem{iip}(f, u0, tspan, p, pt; callback = cbs, disc_saved_values = svs,
735+
kwargs...)
708736
end
709737
end
710738
get_callback(prob::ODEProblem) = prob.kwargs[:callback]

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ struct ODESystem <: AbstractODESystem
129129
"""
130130
discrete_subsystems: a list of discrete subsystems
131131
"""
132-
discrete_subsystems::Union{Nothing, Tuple{Vector{ODESystem}, Vector{Any}, Int}}
132+
discrete_subsystems::Any
133133

134134
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
135135
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,

src/systems/systemstructure.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -467,24 +467,33 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
467467
if state.sys isa ODESystem
468468
ci = ModelingToolkit.ClockInference(state)
469469
ModelingToolkit.infer_clocks!(ci)
470-
tss, inputs, continuous_id = ModelingToolkit.split_system(ci)
470+
tss, inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
471471
cont_io = merge_io(io, inputs[continuous_id])
472-
sys, input_idxs = _structural_simplify!(tss[continuous_id], cont_io; simplify, check_consistency,
473-
kwargs...)
472+
sys, input_idxs = _structural_simplify!(tss[continuous_id], cont_io; simplify,
473+
check_consistency,
474+
kwargs...)
474475
if length(tss) > 1
475476
# TODO: rename it to something else
476477
discrete_subsystems = Vector{ODESystem}(undef, length(tss))
478+
# Note that the appended_parameters must agree with
479+
# `generate_discrete_affect`!
480+
appended_parameters = parameters(sys)
477481
for (i, state) in enumerate(tss)
478482
if i == continuous_id
479483
discrete_subsystems[i] = sys
480484
continue
481485
end
482486
dist_io = merge_io(io, inputs[i])
483-
ss = _structural_simplify!(state, dist_io; simplify, check_consistency,
484-
kwargs...)
485-
push!(discrete_subsystems, ss)
487+
ss, = _structural_simplify!(state, dist_io; simplify, check_consistency,
488+
kwargs...)
489+
append!(appended_parameters, inputs[i], states(ss))
490+
discrete_subsystems[i] = ss
486491
end
487-
@set! sys.discrete_subsystems = discrete_subsystems, inputs, continuous_id
492+
@set! sys.discrete_subsystems = discrete_subsystems, inputs, continuous_id,
493+
id_to_clock
494+
@set! sys.ps = appended_parameters
495+
@set! sys.defaults = merge(ModelingToolkit.defaults(sys),
496+
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
488497
end
489498
else
490499
sys, input_idxs = _structural_simplify!(state, io; simplify, check_consistency,

test/clock.jl

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -112,30 +112,9 @@ eqs = [yd ~ Sample(t, dt)(y)
112112
=#
113113
]
114114
@named sys = ODESystem(eqs)
115-
ci, varmap = infer_clocks(sys)
116-
tss, inputs, continuous_id = ModelingToolkit.split_system(deepcopy(ci))
117-
syss = map(i -> SystemStructures._structural_simplify!(deepcopy(tss[i]), (inputs[i], ()))[1],
118-
eachindex(tss))
119-
sys1, sys2 = syss
120-
@test length(states(sys2)) == 2
121-
z, z_t = states(sys2)
122-
S = Shift(t, 1)
123-
@test full_equations(sys2) == [S(z) ~ z_t; S(z_t) ~ z + Sample(t, dt)(y)]
124-
# TODO: set Hold(ud)
125-
affects, svs, pp, defaults = ModelingToolkit.generate_discrete_affect(syss, inputs,
126-
continuous_id);
127-
@set! sys1.ps = pp
128-
prob = ODEProblem(sys1, [x => 0.0, y => 0.0], (0.0, 1.0), [pp .=> 0.0; kp => 1.0]);
129-
using OrdinaryDiffEq, DiffEqCallbacks
130-
struct DiscreteSaveAffect{F, S} <: Function
131-
f::F
132-
s::S
133-
end
134-
(d::DiscreteSaveAffect)(args...) = d.f(args..., d.s)
135-
gen_affect! = DiscreteSaveAffect(affects[1], svs[1]);
136-
cb = PeriodicCallback(gen_affect!, 0.1);
137-
prob = remake(prob, callback = cb);
138-
sol2 = solve(prob, Tsit5());
115+
ss = structural_simplify(sys)
116+
prob = ODEProblem(ss, [x => 0.0, y => 0.0], (0.0, 1.0), [kp => 1.0; z => 0.0; D(z) => 0.0])
117+
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
139118
# For all inputs in parameters, just initialize them to 0.0, and then set them
140119
# in the callback.
141120

@@ -153,17 +132,19 @@ function affect!(integrator, saved_values)
153132
r = 1.0
154133
ud = kp * (r - yd) + z
155134
push!(saved_values.t, integrator.t)
156-
push!(saved_values.saveval, (integrator.p[3], integrator.p[4]))
135+
push!(saved_values.saveval, [integrator.p[4], integrator.p[3]])
157136
integrator.p[2] = ud
158137
integrator.p[3] = z + yd
159138
integrator.p[4] = z_t
160139
nothing
161140
end
162-
saved_values = SavedValues(Float64, Tuple{Float64, Float64});
163-
cb = PeriodicCallback(Base.Fix2(affect!, saved_values), 0.1);
164-
prob = ODEProblem(foo!, [0.0], (0.0, 1.0), [1.0, 0.0, 0.0, 0.0], callback = cb);
165-
sol = solve(prob, Tsit5());
166-
@test sol.u sol2.u
141+
saved_values = SavedValues(Float64, Vector{Float64});
142+
cb = PeriodicCallback(Base.Fix2(affect!, saved_values), 0.1)
143+
prob = ODEProblem(foo!, [0.0], (0.0, 1.0), [1.0, 0.0, 0.0, 0.0], callback = cb)
144+
sol2 = solve(prob, Tsit5())
145+
@test sol.u == sol2.u
146+
@test saved_values.t == sol.prob.kwargs[:disc_saved_values][1].t
147+
@test saved_values.saveval == sol.prob.kwargs[:disc_saved_values][1].saveval
167148

168149
@info "Testing multi-rate hybrid system"
169150
dt = 0.1

0 commit comments

Comments
 (0)