Skip to content

Commit eaecbe3

Browse files
Merge pull request #2469 from AayushSabharwal/as/no-scalarize
feat!: do not scalarize parameters, fix some tests
2 parents 65799a1 + 3ede8ff commit eaecbe3

29 files changed

+1023
-366
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,7 @@
4545
equations. For example, `[p[1] => 1.0, p[2] => 2.0]` is no longer allowed in default equations, use
4646
`[p => [1.0, 2.0]]` instead. Also, array equations like for `@variables u[1:2]` have `D(u) ~ A*u` as an
4747
array equation. If the scalarized version is desired, use `scalarize(u)`.
48+
- Parameter dependencies are now supported. They can be specified using the syntax
49+
`(single_parameter => expression_involving_other_parameters)` and a `Vector` of these can be passed to
50+
the `parameter_dependencies` keyword argument of `ODESystem`, `SDESystem` and `JumpSystem`. The dependent
51+
parameters are updated whenever other parameters are modified, e.g. in callbacks.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
106106
StaticArrays = "0.10, 0.11, 0.12, 1.0"
107107
SymbolicIndexingInterface = "0.3.1"
108108
SymbolicUtils = "1.0"
109-
Symbolics = "5.7"
109+
Symbolics = "5.20.1"
110110
URIs = "1"
111111
UnPack = "0.1, 1.0"
112112
Unitful = "1.1"

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ using PrecompileTools, Reexport
3636
using RecursiveArrayTools
3737

3838
using SymbolicIndexingInterface
39-
export independent_variables, unknowns, parameters
39+
export independent_variables, unknowns, parameters, full_parameters
4040
import SymbolicUtils
4141
import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype,
4242
Symbolic, isadd, ismul, ispow, issym, FnType,

src/bipartite_graph.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ function Base.push!(m::Matching, v)
8888
end
8989
end
9090

91-
function complete(m::Matching{U}, N = maximum((x for x in m.match if isa(x, Int)); init=0)) where {U}
91+
function complete(m::Matching{U},
92+
N = maximum((x for x in m.match if isa(x, Int)); init = 0)) where {U}
9293
m.inv_match !== nothing && return m
9394
inv_match = Union{U, Int}[unassigned for _ in 1:N]
9495
for (i, eq) in enumerate(m.match)

src/structural_transformation/partial_state_selection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function pss_graph_modia!(structure::SystemStructure, maximal_top_matching, varl
5151
old_level_vars = ()
5252
ict = IncrementalCycleTracker(
5353
DiCMOBiGraph{true}(graph,
54-
complete(Matching(ndsts(graph)), nsrcs(graph))),
54+
complete(Matching(ndsts(graph)), nsrcs(graph))),
5555
dir = :in)
5656

5757
while level >= 0

src/systems/abstractsystem.jl

Lines changed: 63 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,8 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
193193
h = getsymbolhash(sym)
194194
return haskey(ic.unknown_idx, h) ||
195195
haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym))) ||
196-
hasname(sym) && is_variable(sys, getname(sym))
196+
(istree(sym) && operation(sym) === getindex &&
197+
is_variable(sys, first(arguments(sym))))
197198
else
198199
return any(isequal(sym), variable_symbols(sys)) ||
199200
hasname(sym) && is_variable(sys, getname(sym))
@@ -214,18 +215,15 @@ function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
214215
if has_index_cache(sys) && get_index_cache(sys) !== nothing
215216
ic = get_index_cache(sys)
216217
h = getsymbolhash(sym)
217-
return if haskey(ic.unknown_idx, h)
218-
ic.unknown_idx[h]
219-
else
220-
h = getsymbolhash(default_toterm(sym))
221-
if haskey(ic.unknown_idx, h)
222-
ic.unknown_idx[h]
223-
elseif hasname(sym)
224-
variable_index(sys, getname(sym))
225-
else
226-
nothing
227-
end
228-
end
218+
haskey(ic.unknown_idx, h) && return ic.unknown_idx[h]
219+
220+
h = getsymbolhash(default_toterm(sym))
221+
haskey(ic.unknown_idx, h) && return ic.unknown_idx[h]
222+
sym = unwrap(sym)
223+
istree(sym) && operation(sym) === getindex || return nothing
224+
idx = variable_index(sys, first(arguments(sym)))
225+
idx === nothing && return nothing
226+
return idx[arguments(sym)[(begin + 1):end]...]
229227
end
230228
idx = findfirst(isequal(sym), variable_symbols(sys))
231229
if idx === nothing && hasname(sym)
@@ -264,8 +262,7 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
264262
else
265263
h = getsymbolhash(default_toterm(sym))
266264
haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
267-
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) ||
268-
hasname(sym) && is_parameter(sys, getname(sym))
265+
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h)
269266
end
270267
end
271268
return any(isequal(sym), parameter_symbols(sys)) ||
@@ -286,27 +283,12 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
286283
if has_index_cache(sys) && get_index_cache(sys) !== nothing
287284
ic = get_index_cache(sys)
288285
h = getsymbolhash(sym)
289-
return if haskey(ic.param_idx, h)
290-
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
291-
elseif haskey(ic.discrete_idx, h)
292-
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
293-
elseif haskey(ic.constant_idx, h)
294-
ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h])
295-
elseif haskey(ic.dependent_idx, h)
296-
ParameterIndex(nothing, ic.dependent_idx[h])
286+
return if (idx = ParameterIndex(ic, sym)) !== nothing
287+
idx
288+
elseif (idx = ParameterIndex(ic, default_toterm(sym))) !== nothing
289+
idx
297290
else
298-
h = getsymbolhash(default_toterm(sym))
299-
if haskey(ic.param_idx, h)
300-
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
301-
elseif haskey(ic.discrete_idx, h)
302-
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
303-
elseif haskey(ic.constant_idx, h)
304-
ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h])
305-
elseif haskey(ic.dependent_idx, h)
306-
ParameterIndex(nothing, ic.dependent_idx[h])
307-
else
308-
nothing
309-
end
291+
nothing
310292
end
311293
end
312294

@@ -329,7 +311,7 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
329311
end
330312

331313
function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
332-
return parameters(sys)
314+
return full_parameters(sys)
333315
end
334316

335317
function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym)
@@ -419,6 +401,7 @@ for prop in [:eqs
419401
:metadata
420402
:gui_metadata
421403
:discrete_subsystems
404+
:parameter_dependencies
422405
:solved_unknowns
423406
:split_idxs
424407
:parent
@@ -703,9 +686,12 @@ function namespace_expr(O, sys, n = nameof(sys); ivs = independent_variables(sys
703686
# metadata from the rescoped variable
704687
rescoped = renamespace(n, O)
705688
similarterm(O, operation(rescoped), renamed,
706-
metadata = metadata(rescoped))::T
689+
metadata = metadata(rescoped))
690+
elseif Symbolics.isarraysymbolic(O)
691+
# promote_symtype doesn't work for array symbolics
692+
similarterm(O, operation(O), renamed, symtype(O), metadata = metadata(O))
707693
else
708-
similarterm(O, operation(O), renamed, metadata = metadata(O))::T
694+
similarterm(O, operation(O), renamed, metadata = metadata(O))
709695
end
710696
elseif isvariable(O)
711697
renamespace(n, O)
@@ -747,7 +733,29 @@ function parameters(sys::AbstractSystem)
747733
ps = first.(ps)
748734
end
749735
systems = get_systems(sys)
750-
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
736+
result = unique(isempty(systems) ? ps :
737+
[ps; reduce(vcat, namespace_parameters.(systems))])
738+
if has_parameter_dependencies(sys) &&
739+
(pdeps = get_parameter_dependencies(sys)) !== nothing
740+
filter(result) do sym
741+
!haskey(pdeps, sym)
742+
end
743+
else
744+
result
745+
end
746+
end
747+
748+
function dependent_parameters(sys::AbstractSystem)
749+
if has_parameter_dependencies(sys) &&
750+
(pdeps = get_parameter_dependencies(sys)) !== nothing
751+
collect(keys(pdeps))
752+
else
753+
[]
754+
end
755+
end
756+
757+
function full_parameters(sys::AbstractSystem)
758+
vcat(parameters(sys), dependent_parameters(sys))
751759
end
752760

753761
# required in `src/connectors.jl:437`
@@ -1518,13 +1526,12 @@ function linearization_function(sys::AbstractSystem, inputs,
15181526
sys = ssys
15191527
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
15201528
u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
1529+
ps = parameters(sys)
15211530
if has_index_cache(sys) && get_index_cache(sys) !== nothing
15221531
p = MTKParameters(sys, p)
1523-
ps = reorder_parameters(sys, parameters(sys))
15241532
else
15251533
p = _p
15261534
p, split_idxs = split_parameters_by_type(p)
1527-
ps = parameters(sys)
15281535
if p isa Tuple
15291536
ps = Base.Fix1(getindex, ps).(split_idxs)
15301537
ps = (ps...,) #if p is Tuple, ps should be Tuple
@@ -1610,7 +1617,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
16101617
kwargs...)
16111618
sts = unknowns(sys)
16121619
t = get_iv(sys)
1613-
ps = parameters(sys)
1620+
ps = full_parameters(sys)
16141621
p = reorder_parameters(sys, ps)
16151622

16161623
fun = generate_function(sys, sts, ps; expression = Val{false})[1]
@@ -2121,3 +2128,17 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair},
21212128
error("substituting symbols is not supported for $(typeof(sys))")
21222129
end
21232130
end
2131+
2132+
function process_parameter_dependencies(pdeps, ps)
2133+
pdeps === nothing && return pdeps, ps
2134+
if pdeps isa Vector && eltype(pdeps) <: Pair
2135+
pdeps = Dict(pdeps)
2136+
elseif !(pdeps isa Dict)
2137+
error("parameter_dependencies must be a `Dict` or `Vector{<:Pair}`")
2138+
end
2139+
2140+
ps = filter(ps) do p
2141+
!haskey(pdeps, p)
2142+
end
2143+
return pdeps, ps
2144+
end

src/systems/callbacks.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -390,11 +390,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
390390
if has_index_cache(sys) && get_index_cache(sys) !== nothing
391391
ic = get_index_cache(sys)
392392
update_inds = map(update_vars) do sym
393-
@unpack portion, idx = parameter_index(sys, sym)
394-
if portion == SciMLStructures.Discrete()
395-
idx += length(ic.param_idx)
396-
end
397-
idx
393+
pind = parameter_index(sys, sym)
394+
discrete_linear_index(ic, pind)
398395
end
399396
else
400397
psind = Dict(reverse(en) for en in enumerate(ps))
@@ -436,14 +433,14 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
436433
end
437434

438435
function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sys),
439-
ps = parameters(sys); kwargs...)
436+
ps = full_parameters(sys); kwargs...)
440437
cbs = continuous_events(sys)
441438
isempty(cbs) && return nothing
442439
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
443440
end
444441

445442
function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
446-
ps = parameters(sys); kwargs...)
443+
ps = full_parameters(sys); kwargs...)
447444
eqs = map(cb -> cb.eqs, cbs)
448445
num_eqs = length.(eqs)
449446
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
@@ -559,7 +556,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
559556
end
560557

561558
function generate_discrete_callbacks(sys::AbstractSystem, dvs = unknowns(sys),
562-
ps = parameters(sys); kwargs...)
559+
ps = full_parameters(sys); kwargs...)
563560
has_discrete_events(sys) || return nothing
564561
symcbs = discrete_events(sys)
565562
isempty(symcbs) && return nothing

0 commit comments

Comments
 (0)