Skip to content

Commit 93bfbd3

Browse files
YingboMabaggepinnen
andcommitted
WIP: work toward merging clock processing with the common interface
Co-authored-by: Fredrik Bagge Carlson <[email protected]>
1 parent 5d99e9c commit 93bfbd3

File tree

4 files changed

+99
-35
lines changed

4 files changed

+99
-35
lines changed

src/systems/clock_inference.jl

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ function split_system(ci::ClockInference)
150150
return tss, inputs, continuous_id
151151
end
152152

153-
function generate_discrete_affect(syss, inputs, continuous_id, check_bounds = true)
153+
function generate_discrete_affect(syss, inputs, continuous_id; checkbounds = true,
154+
eval_module = @__MODULE__, eval_expression = true)
154155
out = Sym{Any}(:out)
155156
appended_parameters = parameters(syss[continuous_id])
156157
param_to_idx = Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
@@ -161,7 +162,7 @@ function generate_discrete_affect(syss, inputs, continuous_id, check_bounds = tr
161162
i == continuous_id && continue
162163
subs = get_substitutions(sys)
163164
assignments = map(s -> Assignment(s.lhs, s.rhs), subs.subs)
164-
let_body = SetArray(!check_bounds, out, rhss(equations(sys)))
165+
let_body = SetArray(!checkbounds, out, rhss(equations(sys)))
165166
let_block = Let(assignments, let_body, false)
166167
needed_cont_to_disc_obs = map(v -> arguments(v)[1], input)
167168
# TODO: filter the needed ones
@@ -190,27 +191,37 @@ function generate_discrete_affect(syss, inputs, continuous_id, check_bounds = tr
190191
cont_to_disc_idxs = (offset + 1):(offset += ni)
191192
input_offset = offset
192193
disc_range = (offset + 1):(offset += ns)
193-
affect! = quote
194-
function affect!(integrator, saved_values)
195-
@unpack u, p, t = integrator
196-
c2d_obs = $cont_to_disc_obs
197-
d2c_obs = $disc_to_cont_obs
198-
c2d_view = view(p, $cont_to_disc_idxs)
199-
d2c_view = view(p, $disc_to_cont_idxs)
200-
disc_state = view(p, $disc_range)
201-
disc = $disc
202-
# Write continuous info to discrete
203-
# Write discrete info to continuous
204-
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
205-
copyto!(d2c_view, d2c_obs(disc_state, p, t))
206-
push!(saved_values.t, t)
207-
push!(saved_values.saveval, Base.@ntuple $ns i->p[$input_offset + i])
208-
disc(disc_state, disc_state, p, t)
209-
end
194+
save_tuple = Expr(:tuple)
195+
for i in 1:ns
196+
push!(save_tuple.args, :(p[$(input_offset + i)]))
210197
end
198+
affect! = :(function (integrator, saved_values)
199+
@unpack u, p, t = integrator
200+
c2d_obs = $cont_to_disc_obs
201+
d2c_obs = $disc_to_cont_obs
202+
c2d_view = view(p, $cont_to_disc_idxs)
203+
d2c_view = view(p, $disc_to_cont_idxs)
204+
disc_state = view(p, $disc_range)
205+
disc = $disc
206+
# Write continuous info to discrete
207+
# Write discrete info to continuous
208+
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
209+
copyto!(d2c_view, d2c_obs(disc_state, p, t))
210+
push!(saved_values.t, t)
211+
push!(saved_values.saveval, $save_tuple)
212+
disc(disc_state, disc_state, p, t)
213+
end)
211214
sv = SavedValues(Float64, NTuple{ns, Float64})
212215
push!(affect_funs, affect!)
213216
push!(svs, sv)
214217
end
215-
return map(a -> toexpr(LiteralExpr(a)), affect_funs), svs, appended_parameters
218+
if eval_expression
219+
affects = map(affect_funs) do a
220+
@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a)))
221+
end
222+
else
223+
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
224+
end
225+
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
226+
return affects, svs, appended_parameters, defaults
216227
end

src/systems/diffeqs/odesystem.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,17 @@ struct ODESystem <: AbstractODESystem
126126
complete: if a model `sys` is complete, then `sys.x` no longer performs namespacing.
127127
"""
128128
complete::Bool
129+
"""
130+
discrete_subsystems: a list of discrete subsystems
131+
"""
132+
discrete_subsystems::Union{Nothing, Tuple{Vector{ODESystem}, Vector{Any}, Int}}
129133

130134
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
131135
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
132136
torn_matching, connector_type, preface, cevents,
133137
devents, metadata = nothing, tearing_state = nothing,
134-
substitutions = nothing, complete = false;
135-
checks::Union{Bool, Int} = true)
138+
substitutions = nothing, complete = false,
139+
discrete_subsystems = nothing; checks::Union{Bool, Int} = true)
136140
if checks == true || (checks & CheckComponents) > 0
137141
check_variables(dvs, iv)
138142
check_parameters(ps, iv)
@@ -145,7 +149,7 @@ struct ODESystem <: AbstractODESystem
145149
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
146150
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
147151
connector_type, preface, cevents, devents, metadata, tearing_state,
148-
substitutions, complete)
152+
substitutions, complete, discrete_subsystems)
149153
end
150154
end
151155

src/systems/systemstructure.jl

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,49 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
452452
end
453453

454454
# TODO: clean up
455+
function merge_io(io, inputs)
456+
isempty(inputs) && return io
457+
if io === nothing
458+
io = (inputs, [])
459+
else
460+
io = ([inputs; io[1]], io[2])
461+
end
462+
return io
463+
end
464+
455465
function structural_simplify!(state::TearingState, io = nothing; simplify = false,
456466
check_consistency = true, kwargs...)
467+
if state.sys isa ODESystem
468+
ci = ModelingToolkit.ClockInference(state)
469+
ModelingToolkit.infer_clocks!(ci)
470+
tss, inputs, continuous_id = ModelingToolkit.split_system(ci)
471+
cont_io = merge_io(io, inputs[continuous_id])
472+
sys = _structural_simplify!(tss[continous_id], cont_io; simplify, check_consistency,
473+
kwargs...)
474+
if length(tss) > 1
475+
# TODO: rename it to something else
476+
discrete_subsystems = Vector{ODESystem}(undef, length(tss))
477+
for (i, state) in enumerate(tss)
478+
if i == continuous_id
479+
discrete_subsystems[i] = sys
480+
continue
481+
end
482+
dist_io = merge_io(io, inputs[i])
483+
ss = _structural_simplify!(state, dist_io; simplify, check_consistency,
484+
kwargs...)
485+
push!(discrete_subsystems, ss)
486+
end
487+
@set! sys.discrete_subsystems = discrete_subsystems, inputs, continuous_id
488+
end
489+
else
490+
sys, input_idxs = _structural_simplify!(state, io; simplify, check_consistency,
491+
kwargs...)
492+
end
493+
return has_io ? (sys, input_idxs) : sys
494+
end
495+
496+
function _structural_simplify!(state::TearingState, io; simplify = false,
497+
check_consistency = true, kwargs...)
457498
has_io = io !== nothing
458499
has_io && ModelingToolkit.markio!(state, io...)
459500
state, input_idxs = ModelingToolkit.inputs_to_parameters!(state, io)
@@ -464,8 +505,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
464505
sys = ModelingToolkit.dummy_derivative(sys, state, ag; simplify)
465506
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
466507
@set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullstates)
467-
ModelingToolkit.invalidate_cache!(sys)
468-
return has_io ? (sys, input_idxs) : sys
508+
ModelingToolkit.invalidate_cache!(sys), input_idxs
469509
end
470510

471511
end # module

test/clock.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,13 @@ By inference:
6161
=> Shift(x, 0, dt) := (Shift(x, -1, dt) + dt) / (1 - dt) # Discrete system
6262
=#
6363

64+
using ModelingToolkit.SystemStructures
6465
ci, varmap = infer_clocks(sys)
6566
eqmap = ci.eq_domain
6667
tss, inputs = ModelingToolkit.split_system(deepcopy(ci))
67-
sss, = ModelingToolkit.structural_simplify!(deepcopy(tss[1]), (inputs[1], ()))
68+
sss, = SystemStructures._structural_simplify!(deepcopy(tss[1]), (inputs[1], ()))
6869
@test equations(sss) == [D(x) ~ u - x]
69-
sss, = ModelingToolkit.structural_simplify!(deepcopy(tss[2]), (inputs[2], ()))
70+
sss, = SystemStructures._structural_simplify!(deepcopy(tss[2]), (inputs[2], ()))
7071
@test isempty(equations(sss))
7172
@test observed(sss) == [r ~ 1.0; yd ~ Sample(t, dt)(y); ud ~ kp * (r - yd)]
7273

@@ -112,23 +113,31 @@ eqs = [yd ~ Sample(t, dt)(y)
112113
]
113114
@named sys = ODESystem(eqs)
114115
ci, varmap = infer_clocks(sys)
115-
tss, inputs = ModelingToolkit.split_system(deepcopy(ci))
116-
syss = map(i -> ModelingToolkit.structural_simplify!(deepcopy(tss[i]), (inputs[i], ()))[1],
116+
tss, inputs, continuous_id = ModelingToolkit.split_system(deepcopy(ci))
117+
syss = map(i -> SystemStructures._structural_simplify!(deepcopy(tss[i]), (inputs[i], ()))[1],
117118
eachindex(tss))
118119
sys1, sys2 = syss
119120
@test length(states(sys2)) == 2
120121
z, z_t = states(sys2)
121122
S = Shift(t, 1)
122123
@test full_equations(sys2) == [S(z) ~ z_t; S(z_t) ~ z + Sample(t, dt)(y)]
123124
# TODO: set Hold(ud)
124-
prob = ODEProblem(sys1, [x => 0.0, y => 0.0], (0.0, 1.0), [kp => 1.0, Hold(ud) => 0.0]);
125+
affects, svs, pp, defaults = ModelingToolkit.generate_discrete_affect(syss, inputs,
126+
continuous_id);
127+
@set! sys1.ps = pp
128+
prob = ODEProblem(sys1, [x => 0.0, y => 0.0], (0.0, 1.0), [pp .=> 0.0; kp => 1.0]);
125129
using OrdinaryDiffEq, DiffEqCallbacks
126-
exprs, svs, pp = ModelingToolkit.generate_discrete_affect(syss, inputs, 1);
127-
prob = remake(prob, p = zeros(Float64, length(pp)));
128-
prob.p[1] = 1.0;
129-
gen_affect! = Base.Fix2(eval(exprs[1]), svs[1]);
130+
struct DiscreteSaveAffect{F, S} <: Function
131+
f::F
132+
s::S
133+
end
134+
(d::DiscreteSaveAffect)(args...) = d.f(args..., d.s)
135+
gen_affect! = DiscreteSaveAffect(affects[1], svs[1]);
130136
cb = PeriodicCallback(gen_affect!, 0.1);
131-
sol2 = solve(prob, Tsit5(), callback = cb);
137+
prob = remake(prob, callback = cb);
138+
sol2 = solve(prob, Tsit5());
139+
# For all inputs in parameters, just initialize them to 0.0, and then set them
140+
# in the callback.
132141

133142
# kp is the only real parameter
134143
function foo!(du, u, p, t)

0 commit comments

Comments
 (0)