Skip to content

Commit 29040fc

Browse files
Merge pull request #2927 from AayushSabharwal/as/fix-indexing-ci
fix: fix downstream indexing tests
2 parents 50a4b12 + 750e82f commit 29040fc

File tree

8 files changed

+60
-303
lines changed

8 files changed

+60
-303
lines changed

src/systems/abstractsystem.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,8 +563,11 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
563563
if idx.portion isa SciMLStructures.Discrete &&
564564
idx.idx[2] == idx.idx[3] == nothing
565565
return nothing
566+
elseif idx.portion isa SciMLStructures.Tunable
567+
return ParameterIndex(
568+
idx.portion, idx.idx[arguments(sym)[(begin + 1):end]...])
566569
else
567-
ParameterIndex(
570+
return ParameterIndex(
568571
idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...))
569572
end
570573
else

src/systems/callbacks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,8 @@ namespace_affects(::Nothing, s) = nothing
190190
function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
191191
SymbolicContinuousCallback(
192192
namespace_equation.(equations(cb), (s,)),
193-
namespace_affects(affects(cb), s),
194-
namespace_affects(affect_negs(cb), s))
193+
namespace_affects(affects(cb), s);
194+
affect_neg = namespace_affects(affect_negs(cb), s))
195195
end
196196

197197
"""

src/systems/clock_inference.jl

Lines changed: 0 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -195,141 +195,3 @@ function split_system(ci::ClockInference{S}) where {S}
195195
end
196196
return tss, inputs, continuous_id, id_to_clock
197197
end
198-
199-
function generate_discrete_affect(
200-
osys::AbstractODESystem, syss, inputs, continuous_id, id_to_clock;
201-
checkbounds = true,
202-
eval_module = @__MODULE__, eval_expression = false)
203-
@static if VERSION < v"1.7"
204-
error("The `generate_discrete_affect` function requires at least Julia 1.7")
205-
end
206-
has_index_cache(osys) && get_index_cache(osys) !== nothing ||
207-
error("Hybrid systems require `split = true`")
208-
out = Sym{Any}(:out)
209-
appended_parameters = full_parameters(syss[continuous_id])
210-
offset = length(appended_parameters)
211-
param_to_idx = Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
212-
for p in appended_parameters)
213-
affect_funs = []
214-
clocks = TimeDomain[]
215-
for (i, (sys, input)) in enumerate(zip(syss, inputs))
216-
i == continuous_id && continue
217-
push!(clocks, id_to_clock[i])
218-
subs = get_substitutions(sys)
219-
assignments = map(s -> Assignment(s.lhs, s.rhs), subs.subs)
220-
let_body = SetArray(!checkbounds, out, rhss(equations(sys)))
221-
let_block = Let(assignments, let_body, false)
222-
needed_cont_to_disc_obs = map(v -> arguments(v)[1], input)
223-
# TODO: filter the needed ones
224-
fullvars = Set{Any}(eq.lhs for eq in observed(sys))
225-
for s in unknowns(sys)
226-
push!(fullvars, s)
227-
end
228-
needed_disc_to_cont_obs = []
229-
disc_to_cont_idxs = ParameterIndex[]
230-
for v in inputs[continuous_id]
231-
_v = arguments(v)[1]
232-
if _v in fullvars
233-
push!(needed_disc_to_cont_obs, _v)
234-
push!(disc_to_cont_idxs, param_to_idx[v])
235-
continue
236-
end
237-
238-
# If the held quantity is calculated through observed
239-
# it will be shifted forward by 1
240-
_v = Shift(get_iv(sys), 1)(_v)
241-
if _v in fullvars
242-
push!(needed_disc_to_cont_obs, _v)
243-
push!(disc_to_cont_idxs, param_to_idx[v])
244-
continue
245-
end
246-
end
247-
append!(appended_parameters, input)
248-
cont_to_disc_obs = build_explicit_observed_function(
249-
osys,
250-
needed_cont_to_disc_obs,
251-
throw = false,
252-
expression = true,
253-
output_type = SVector)
254-
disc_to_cont_obs = build_explicit_observed_function(sys, needed_disc_to_cont_obs,
255-
throw = false,
256-
expression = true,
257-
output_type = SVector,
258-
op = Shift,
259-
ps = reorder_parameters(osys, appended_parameters))
260-
ni = length(input)
261-
ns = length(unknowns(sys))
262-
disc = Func(
263-
[
264-
out,
265-
DestructuredArgs(unknowns(osys)),
266-
DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))...,
267-
get_iv(sys)
268-
],
269-
[],
270-
let_block) |> toexpr
271-
cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input]
272-
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]
273-
save_expr = :($(SciMLBase.save_discretes!)(integrator, $i))
274-
empty_disc = isempty(disc_range)
275-
276-
# @show disc_to_cont_idxs
277-
# @show cont_to_disc_idxs
278-
# @show disc_range
279-
affect! = :(function (integrator)
280-
@unpack u, p, t = integrator
281-
c2d_obs = $cont_to_disc_obs
282-
d2c_obs = $disc_to_cont_obs
283-
# TODO: find a way to do this without allocating
284-
disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range]
285-
disc = $disc
286-
287-
# Write continuous into to discrete: handles `Sample`
288-
# Write discrete into to continuous
289-
# Update discrete unknowns
290-
291-
# At a tick, c2d must come first
292-
# state update comes in the middle
293-
# d2c comes last
294-
# @show t
295-
# @show "incoming", p
296-
result = c2d_obs(u, p..., t)
297-
for (val, i) in zip(result, $cont_to_disc_idxs)
298-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
299-
end
300-
$(if !empty_disc
301-
quote
302-
disc(disc_unknowns, u, p..., t)
303-
for (val, i) in zip(disc_unknowns, $disc_range)
304-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
305-
end
306-
end
307-
end)
308-
# @show "after c2d", p
309-
# @show "after state update", p
310-
result = d2c_obs(disc_unknowns, p..., t)
311-
for (val, i) in zip(result, $disc_to_cont_idxs)
312-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
313-
end
314-
315-
$save_expr
316-
317-
# @show "after d2c", p
318-
discretes, repack, _ = $(SciMLStructures.canonicalize)(
319-
$(SciMLStructures.Discrete()), p)
320-
repack(discretes)
321-
end)
322-
323-
push!(affect_funs, affect!)
324-
end
325-
if eval_expression
326-
affects = map(a -> eval_module.eval(toexpr(LiteralExpr(a))), affect_funs)
327-
else
328-
affects = map(affect_funs) do a
329-
drop_expr(RuntimeGeneratedFunction(
330-
eval_module, eval_module, toexpr(LiteralExpr(a))))
331-
end
332-
end
333-
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
334-
return affects, clocks, appended_parameters, defaults
335-
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 3 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -782,12 +782,6 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
782782
varlist = collect(map(unwrap, dvs))
783783
missingvars = setdiff(varlist, collect(keys(varmap)))
784784

785-
# Append zeros to the variables which are determined by the initialization system
786-
# This essentially bypasses the check for if initial conditions are defined for DAEs
787-
# since they will be checked in the initialization problem's construction
788-
# TODO: make check for if a DAE cheaper than calculating the mass matrix a second time!
789-
ci = infer_clocks!(ClockInference(TearingState(sys)))
790-
791785
if eltype(parammap) <: Pair
792786
parammap = Dict(unwrap(k) => v for (k, v) in todict(parammap))
793787
elseif parammap isa AbstractArray
@@ -798,38 +792,9 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
798792
end
799793
end
800794

801-
if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing
802-
clockedparammap = Dict()
803-
defs = ModelingToolkit.get_defaults(sys)
804-
for v in ps
805-
v = unwrap(v)
806-
is_discrete_domain(v) || continue
807-
op = operation(v)
808-
if !isa(op, Symbolics.Operator) && parammap != SciMLBase.NullParameters() &&
809-
haskey(parammap, v)
810-
error("Initial conditions for discrete variables must be for the past state of the unknown. Instead of providing the condition for $v, provide the condition for $(Shift(iv, -1)(v)).")
811-
end
812-
shiftedv = StructuralTransformations.simplify_shifts(Shift(iv, -1)(v))
813-
if parammap != SciMLBase.NullParameters() &&
814-
(val = get(parammap, shiftedv, nothing)) !== nothing
815-
clockedparammap[v] = val
816-
elseif op isa Shift
817-
root = arguments(v)[1]
818-
haskey(defs, root) || error("Initial condition for $v not provided.")
819-
clockedparammap[v] = defs[root]
820-
end
821-
end
822-
parammap = if parammap == SciMLBase.NullParameters()
823-
clockedparammap
824-
else
825-
merge(parammap, clockedparammap)
826-
end
827-
end
828-
# TODO: make it work with clocks
829795
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
830796
if sys isa ODESystem && build_initializeprob &&
831797
(((implicit_dae || !isempty(missingvars)) &&
832-
all(==(Continuous), ci.var_domain) &&
833798
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
834799
!isempty(initialization_equations(sys))) && t !== nothing
835800
if eltype(u0map) <: Number
@@ -1010,29 +975,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
1010975
t = tspan !== nothing ? tspan[1] : tspan,
1011976
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
1012977
cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
1013-
inits = []
1014-
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1015-
affects, clocks = ModelingToolkit.generate_discrete_affect(
1016-
sys, dss...; eval_expression, eval_module)
1017-
discrete_cbs = map(affects, clocks) do affect, clock
1018-
@match clock begin
1019-
PeriodicClock(dt, _...) => PeriodicCallback(affect, dt;
1020-
final_affect = true, initial_affect = true)
1021-
&SolverStepClock => DiscreteCallback(Returns(true), affect,
1022-
initialize = (c, u, t, integrator) -> affect(integrator))
1023-
_ => error("$clock is not a supported clock type.")
1024-
end
1025-
end
1026-
if cbs === nothing
1027-
if length(discrete_cbs) == 1
1028-
cbs = only(discrete_cbs)
1029-
else
1030-
cbs = CallbackSet(discrete_cbs...)
1031-
end
1032-
else
1033-
cbs = CallbackSet(cbs, discrete_cbs...)
1034-
end
1035-
end
978+
1036979
kwargs = filter_kwargs(kwargs)
1037980
pt = something(get_metadata(sys), StandardODEProblem())
1038981

@@ -1112,40 +1055,14 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
11121055
h(p, t) = h_oop(p, t)
11131056
h(p::MTKParameters, t) = h_oop(p..., t)
11141057
u0 = h(p, tspan[1])
1058+
11151059
cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
1116-
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1117-
affects, clocks = ModelingToolkit.generate_discrete_affect(
1118-
sys, dss...; eval_expression, eval_module)
1119-
discrete_cbs = map(affects, clocks) do affect, clock
1120-
@match clock begin
1121-
PeriodicClock(dt, _...) => PeriodicCallback(affect, dt;
1122-
final_affect = true, initial_affect = true)
1123-
&SolverStepClock => DiscreteCallback(Returns(true), affect,
1124-
initialize = (c, u, t, integrator) -> affect(integrator))
1125-
_ => error("$clock is not a supported clock type.")
1126-
end
1127-
end
1128-
if cbs === nothing
1129-
if length(discrete_cbs) == 1
1130-
cbs = only(discrete_cbs)
1131-
else
1132-
cbs = CallbackSet(discrete_cbs...)
1133-
end
1134-
else
1135-
cbs = CallbackSet(cbs, discrete_cbs)
1136-
end
1137-
else
1138-
svs = nothing
1139-
end
11401060
kwargs = filter_kwargs(kwargs)
11411061

11421062
kwargs1 = (;)
11431063
if cbs !== nothing
11441064
kwargs1 = merge(kwargs1, (callback = cbs,))
11451065
end
1146-
if svs !== nothing
1147-
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
1148-
end
11491066
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
11501067
end
11511068

@@ -1175,40 +1092,14 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
11751092
h(p::MTKParameters, t) = h_oop(p..., t)
11761093
h(out, p::MTKParameters, t) = h_iip(out, p..., t)
11771094
u0 = h(p, tspan[1])
1095+
11781096
cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
1179-
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1180-
affects, clocks = ModelingToolkit.generate_discrete_affect(
1181-
sys, dss...; eval_expression, eval_module)
1182-
discrete_cbs = map(affects, clocks) do affect, clock
1183-
@match clock begin
1184-
PeriodicClock(dt, _...) => PeriodicCallback(affect, dt;
1185-
final_affect = true, initial_affect = true)
1186-
&SolverStepClock => DiscreteCallback(Returns(true), affect,
1187-
initialize = (c, u, t, integrator) -> affect(integrator))
1188-
_ => error("$clock is not a supported clock type.")
1189-
end
1190-
end
1191-
if cbs === nothing
1192-
if length(discrete_cbs) == 1
1193-
cbs = only(discrete_cbs)
1194-
else
1195-
cbs = CallbackSet(discrete_cbs...)
1196-
end
1197-
else
1198-
cbs = CallbackSet(cbs, discrete_cbs)
1199-
end
1200-
else
1201-
svs = nothing
1202-
end
12031097
kwargs = filter_kwargs(kwargs)
12041098

12051099
kwargs1 = (;)
12061100
if cbs !== nothing
12071101
kwargs1 = merge(kwargs1, (callback = cbs,))
12081102
end
1209-
if svs !== nothing
1210-
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
1211-
end
12121103

12131104
noiseeqs = get_noiseeqs(sys)
12141105
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))

src/systems/diffeqs/odesystem.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -401,15 +401,6 @@ function build_explicit_observed_function(sys, ts;
401401
dep_vars = scalarize(setdiff(vars, ivs))
402402

403403
obs = param_only ? Equation[] : observed(sys)
404-
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
405-
# each subsystem is topologically sorted independently. We can append the
406-
# equations to override the `lhs ~ 0` equations in `observed(sys)`
407-
syss, _, continuous_id, _... = dss
408-
for (i, subsys) in enumerate(syss)
409-
i == continuous_id && continue
410-
append!(obs, observed(subsys))
411-
end
412-
end
413404

414405
cs = collect_constants(obs)
415406
if !isempty(cs) > 0

0 commit comments

Comments
 (0)