Skip to content

Commit 6e49923

Browse files
authored
Merge pull request #1949 from SciML/myb_fb/clocks
WIP: work toward merging clock processing with the common interface
2 parents 4ab3846 + 17be238 commit 6e49923

File tree

10 files changed

+196
-89
lines changed

10 files changed

+196
-89
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ export @variables, @parameters, @constants
230230
export @named, @nonamespace, @namespace, extend, compose, complete
231231
export debug_system
232232

233-
export Continuous, Discrete, sampletime, input_timedomain, output_timedomain
233+
#export Continuous, Discrete, sampletime, input_timedomain, output_timedomain
234234
#export has_discrete_domain, has_continuous_domain
235235
#export is_discrete_domain, is_continuous_domain, is_hybrid_domain
236236
export Sample, Hold, Shift, ShiftIndex

src/discretedomain.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ struct Shift <: Operator
2727
steps::Int
2828
Shift(t, steps = 1) = new(value(t), steps)
2929
end
30+
normalize_to_differential(s::Shift) = Differential(s.t)^s.steps
3031
function (D::Shift)(x, allow_zero = false)
3132
!allow_zero && D.steps == 0 && return x
3233
Term{symtype(x)}(D, Any[x])

src/systems/abstractsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,8 @@ for prop in [:eqs
208208
:torn_matching
209209
:tearing_state
210210
:substitutions
211-
:metadata]
211+
:metadata
212+
:discrete_subsystems]
212213
fname1 = Symbol(:get_, prop)
213214
fname2 = Symbol(:has_, prop)
214215
@eval begin

src/systems/clock_inference.jl

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,16 @@ end
88
function ClockInference(ts::TearingState)
99
@unpack fullvars, structure = ts
1010
@unpack graph = structure
11-
eq_domain = Vector{TimeDomain}(undef, nsrcs(graph))
12-
var_domain = Vector{TimeDomain}(undef, ndsts(graph))
11+
eq_domain = TimeDomain[Continuous() for _ in 1:nsrcs(graph)]
12+
var_domain = TimeDomain[Continuous() for _ in 1:ndsts(graph)]
1313
inferred = BitSet()
1414
for (i, v) in enumerate(fullvars)
1515
d = get_time_domain(v)
1616
if d isa Union{AbstractClock, Continuous}
1717
push!(inferred, i)
1818
dd = d
19-
else
20-
dd = Inferred()
19+
var_domain[i] = dd
2120
end
22-
var_domain[i] = dd
2321
end
2422
ClockInference(ts, eq_domain, var_domain, inferred)
2523
end
@@ -28,6 +26,7 @@ function infer_clocks!(ci::ClockInference)
2826
@unpack ts, eq_domain, var_domain, inferred = ci
2927
@unpack fullvars = ts
3028
@unpack graph = ts.structure
29+
isempty(inferred) && return ci
3130
# TODO: add a graph type to do this lazily
3231
var_graph = SimpleGraph(ndsts(graph))
3332
for eq in 𝑠vertices(graph)
@@ -58,7 +57,6 @@ function infer_clocks!(ci::ClockInference)
5857
vd = var_domain[v]
5958
eqs = 𝑑neighbors(graph, v)
6059
isempty(eqs) && continue
61-
#eq = first(eqs)
6260
for eq in eqs
6361
eq_domain[eq] = vd
6462
end
@@ -116,7 +114,6 @@ function split_system(ci::ClockInference)
116114
@assert cid!==0 "Internal error! Variable $(fullvars[i]) doesn't have a inferred time domain."
117115
var_to_cid[i] = cid
118116
v = fullvars[i]
119-
#TODO: remove Inferred*
120117
if istree(v) && (o = operation(v)) isa Operator &&
121118
input_timedomain(o) != output_timedomain(o)
122119
push!(input_idxs[cid], i)
@@ -147,21 +144,28 @@ function split_system(ci::ClockInference)
147144
@set! ts_i.structure.eq_to_diff = eq_to_diff
148145
tss[id] = ts_i
149146
end
150-
return tss, inputs, continuous_id
147+
return tss, inputs, continuous_id, id_to_clock
151148
end
152149

153-
function generate_discrete_affect(syss, inputs, continuous_id, check_bounds = true)
150+
function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
151+
checkbounds = true,
152+
eval_module = @__MODULE__, eval_expression = true)
153+
@static if VERSION < v"1.7"
154+
error("The `generate_discrete_affect` function requires at least Julia 1.7")
155+
end
154156
out = Sym{Any}(:out)
155157
appended_parameters = parameters(syss[continuous_id])
156158
param_to_idx = Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
157159
offset = length(appended_parameters)
158160
affect_funs = []
159161
svs = []
162+
clocks = TimeDomain[]
160163
for (i, (sys, input)) in enumerate(zip(syss, inputs))
161164
i == continuous_id && continue
165+
push!(clocks, id_to_clock[i])
162166
subs = get_substitutions(sys)
163167
assignments = map(s -> Assignment(s.lhs, s.rhs), subs.subs)
164-
let_body = SetArray(!check_bounds, out, rhss(equations(sys)))
168+
let_body = SetArray(!checkbounds, out, rhss(equations(sys)))
165169
let_block = Let(assignments, let_body, false)
166170
needed_cont_to_disc_obs = map(v -> arguments(v)[1], input)
167171
# TODO: filter the needed ones
@@ -190,27 +194,37 @@ function generate_discrete_affect(syss, inputs, continuous_id, check_bounds = tr
190194
cont_to_disc_idxs = (offset + 1):(offset += ni)
191195
input_offset = offset
192196
disc_range = (offset + 1):(offset += ns)
193-
affect! = quote
194-
function affect!(integrator, saved_values)
195-
@unpack u, p, t = integrator
196-
c2d_obs = $cont_to_disc_obs
197-
d2c_obs = $disc_to_cont_obs
198-
c2d_view = view(p, $cont_to_disc_idxs)
199-
d2c_view = view(p, $disc_to_cont_idxs)
200-
disc_state = view(p, $disc_range)
201-
disc = $disc
202-
# Write continuous info to discrete
203-
# Write discrete info to continuous
204-
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
205-
copyto!(d2c_view, d2c_obs(disc_state, p, t))
206-
push!(saved_values.t, t)
207-
push!(saved_values.saveval, Base.@ntuple $ns i->p[$input_offset + i])
208-
disc(disc_state, disc_state, p, t)
209-
end
197+
save_vec = Expr(:ref, :Float64)
198+
for i in 1:ns
199+
push!(save_vec.args, :(p[$(input_offset + i)]))
210200
end
211-
sv = SavedValues(Float64, NTuple{ns, Float64})
201+
affect! = :(function (integrator, saved_values)
202+
@unpack u, p, t = integrator
203+
c2d_obs = $cont_to_disc_obs
204+
d2c_obs = $disc_to_cont_obs
205+
c2d_view = view(p, $cont_to_disc_idxs)
206+
d2c_view = view(p, $disc_to_cont_idxs)
207+
disc_state = view(p, $disc_range)
208+
disc = $disc
209+
# Write continuous info to discrete
210+
# Write discrete info to continuous
211+
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
212+
copyto!(d2c_view, d2c_obs(disc_state, p, t))
213+
push!(saved_values.t, t)
214+
push!(saved_values.saveval, $save_vec)
215+
disc(disc_state, disc_state, p, t)
216+
end)
217+
sv = SavedValues(Float64, Vector{Float64})
212218
push!(affect_funs, affect!)
213219
push!(svs, sv)
214220
end
215-
return map(a -> toexpr(LiteralExpr(a)), affect_funs), svs, appended_parameters
221+
if eval_expression
222+
affects = map(affect_funs) do a
223+
@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a)))
224+
end
225+
else
226+
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
227+
end
228+
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
229+
return affects, clocks, svs, appended_parameters, defaults
216230
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,12 @@ function DiffEqBase.ODEProblem{false}(sys::AbstractODESystem, args...; kwargs...
686686
ODEProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
687687
end
688688

689+
struct DiscreteSaveAffect{F, S} <: Function
690+
f::F
691+
s::S
692+
end
693+
(d::DiscreteSaveAffect)(args...) = d.f(args..., d.s)
694+
689695
function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
690696
tspan = get_tspan(sys),
691697
parammap = DiffEqBase.NullParameters();
@@ -698,14 +704,38 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
698704
has_difference = has_difference,
699705
check_length, kwargs...)
700706
cbs = process_events(sys; callback, has_difference, kwargs...)
707+
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
708+
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
709+
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
710+
if clock isa Clock
711+
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
712+
else
713+
error("$clock is not a supported clock type.")
714+
end
715+
end
716+
if cbs === nothing
717+
if length(discrete_cbs) == 1
718+
cbs = only(discrete_cbs)
719+
else
720+
cbs = CallbackSet(discrete_cbs...)
721+
end
722+
else
723+
cbs = CallbackSet(cbs, discrete_cbs)
724+
end
725+
else
726+
svs = nothing
727+
end
701728
kwargs = filter_kwargs(kwargs)
702729
pt = something(get_metadata(sys), StandardODEProblem())
703730

704-
if cbs === nothing
705-
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs...)
706-
else
707-
ODEProblem{iip}(f, u0, tspan, p, pt; callback = cbs, kwargs...)
731+
kwargs1 = (;)
732+
if cbs !== nothing
733+
kwargs1 = merge(kwargs1, (callback = cbs,))
734+
end
735+
if svs !== nothing
736+
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
708737
end
738+
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
709739
end
710740
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
711741

src/systems/diffeqs/odesystem.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,17 @@ struct ODESystem <: AbstractODESystem
126126
complete: if a model `sys` is complete, then `sys.x` no longer performs namespacing.
127127
"""
128128
complete::Bool
129+
"""
130+
discrete_subsystems: a list of discrete subsystems
131+
"""
132+
discrete_subsystems::Any
129133

130134
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
131135
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
132136
torn_matching, connector_type, preface, cevents,
133137
devents, metadata = nothing, tearing_state = nothing,
134-
substitutions = nothing, complete = false;
135-
checks::Union{Bool, Int} = true)
138+
substitutions = nothing, complete = false,
139+
discrete_subsystems = nothing; checks::Union{Bool, Int} = true)
136140
if checks == true || (checks & CheckComponents) > 0
137141
check_variables(dvs, iv)
138142
check_parameters(ps, iv)
@@ -145,7 +149,7 @@ struct ODESystem <: AbstractODESystem
145149
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
146150
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
147151
connector_type, preface, cevents, devents, metadata, tearing_state,
148-
substitutions, complete)
152+
substitutions, complete, discrete_subsystems)
149153
end
150154
end
151155

src/systems/systemstructure.jl

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten,
99
value, InvalidSystemException, isdifferential, _iszero,
1010
isparameter, isconstant,
1111
independent_variables, SparseMatrixCLIL, AbstractSystem,
12-
equations, isirreducible
12+
equations, isirreducible, input_timedomain, TimeDomain
1313
using ..BipartiteGraphs
1414
import ..BipartiteGraphs: invview, complete
1515
using Graphs
@@ -285,7 +285,7 @@ function TearingState(sys; quick_cancel = false, check = true)
285285
!isdifferential(var) && (it = input_timedomain(var)) !== nothing
286286
set_incidence = false
287287
var = only(arguments(var))
288-
var = setmetadata(var, ModelingToolkit.TimeDomain, it)
288+
var = setmetadata(var, TimeDomain, it)
289289
@goto ANOTHER_VAR
290290
end
291291
end
@@ -452,8 +452,59 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
452452
end
453453

454454
# TODO: clean up
455+
function merge_io(io, inputs)
456+
isempty(inputs) && return io
457+
if io === nothing
458+
io = (inputs, [])
459+
else
460+
io = ([inputs; io[1]], io[2])
461+
end
462+
return io
463+
end
464+
455465
function structural_simplify!(state::TearingState, io = nothing; simplify = false,
456466
check_consistency = true, kwargs...)
467+
if state.sys isa ODESystem
468+
ci = ModelingToolkit.ClockInference(state)
469+
ModelingToolkit.infer_clocks!(ci)
470+
tss, inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
471+
cont_io = merge_io(io, inputs[continuous_id])
472+
sys, input_idxs = _structural_simplify!(tss[continuous_id], cont_io; simplify,
473+
check_consistency,
474+
kwargs...)
475+
if length(tss) > 1
476+
# TODO: rename it to something else
477+
discrete_subsystems = Vector{ODESystem}(undef, length(tss))
478+
# Note that the appended_parameters must agree with
479+
# `generate_discrete_affect`!
480+
appended_parameters = parameters(sys)
481+
for (i, state) in enumerate(tss)
482+
if i == continuous_id
483+
discrete_subsystems[i] = sys
484+
continue
485+
end
486+
dist_io = merge_io(io, inputs[i])
487+
ss, = _structural_simplify!(state, dist_io; simplify, check_consistency,
488+
kwargs...)
489+
append!(appended_parameters, inputs[i], states(ss))
490+
discrete_subsystems[i] = ss
491+
end
492+
@set! sys.discrete_subsystems = discrete_subsystems, inputs, continuous_id,
493+
id_to_clock
494+
@set! sys.ps = appended_parameters
495+
@set! sys.defaults = merge(ModelingToolkit.defaults(sys),
496+
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
497+
end
498+
else
499+
sys, input_idxs = _structural_simplify!(state, io; simplify, check_consistency,
500+
kwargs...)
501+
end
502+
has_io = io !== nothing
503+
return has_io ? (sys, input_idxs) : sys
504+
end
505+
506+
function _structural_simplify!(state::TearingState, io; simplify = false,
507+
check_consistency = true, kwargs...)
457508
has_io = io !== nothing
458509
has_io && ModelingToolkit.markio!(state, io...)
459510
state, input_idxs = ModelingToolkit.inputs_to_parameters!(state, io)
@@ -464,8 +515,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
464515
sys = ModelingToolkit.dummy_derivative(sys, state, ag; simplify)
465516
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
466517
@set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullstates)
467-
ModelingToolkit.invalidate_cache!(sys)
468-
return has_io ? (sys, input_idxs) : sys
518+
ModelingToolkit.invalidate_cache!(sys), input_idxs
469519
end
470520

471521
end # module

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,3 +858,5 @@ function fast_substitute(expr, pair::Pair)
858858
symtype(expr);
859859
metadata = metadata(expr))
860860
end
861+
862+
normalize_to_differential(s) = s

src/variables.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@ isoutput(x) = isvarkind(VariableOutput, x)
3636
isirreducible(x) = isvarkind(VariableIrreducible, x)
3737
state_priority(x) = convert(Float64, getmetadata(x, VariableStatePriority, 0.0))::Float64
3838

39+
function default_toterm(x)
40+
if istree(x) && (op = operation(x)) isa Operator
41+
if !(op isa Differential)
42+
x = normalize_to_differential(op)(arguments(x)...)
43+
end
44+
Symbolics.diff2term(x)
45+
else
46+
x
47+
end
48+
end
49+
3950
"""
4051
$(SIGNATURES)
4152
@@ -44,7 +55,7 @@ and creates the array of values in the correct order with default values when
4455
applicable.
4556
"""
4657
function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
47-
toterm = Symbolics.diff2term, promotetoconcrete = nothing,
58+
toterm = default_toterm, promotetoconcrete = nothing,
4859
tofloat = true, use_union = false)
4960
varlist = collect(map(unwrap, varlist))
5061

0 commit comments

Comments
 (0)