Skip to content

Commit 86c82ce

Browse files
committed
Merge remote-tracking branch 'origin/master' into MTK
2 parents b3da813 + e9fe9a1 commit 86c82ce

23 files changed

+409
-44
lines changed

.github/workflows/Tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ jobs:
3232
group:
3333
- InterfaceI
3434
- InterfaceII
35+
- Initialization
3536
- SymbolicIndexingInterface
3637
- Extended
3738
- Extensions

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Yingbo Ma <[email protected]>", "Chris Rackauckas <[email protected]> and contributors"]
4-
version = "9.53.0"
4+
version = "9.54.0"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -64,6 +64,7 @@ BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
6464
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6565
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
6666
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
67+
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
6768
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
6869

6970
[extensions]
@@ -72,6 +73,7 @@ MTKChainRulesCoreExt = "ChainRulesCore"
7273
MTKDeepDiffsExt = "DeepDiffs"
7374
MTKHomotopyContinuationExt = "HomotopyContinuation"
7475
MTKLabelledArraysExt = "LabelledArrays"
76+
MTKInfiniteOptExt = "InfiniteOpt"
7577

7678
[compat]
7779
AbstractTrees = "0.3, 0.4"
@@ -104,6 +106,7 @@ FunctionWrappers = "1.1"
104106
FunctionWrappersWrappers = "0.1"
105107
Graphs = "1.5.2"
106108
HomotopyContinuation = "2.11"
109+
InfiniteOpt = "0.5"
107110
InteractiveUtils = "1"
108111
JuliaFormatter = "1.0.47"
109112
JumpProcesses = "9.13.1"

docs/src/tutorials/initialization.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ the `initialization_eqs` keyword argument, for example:
113113

114114
```@example init
115115
prob = ODEProblem(pend, [x => 1], (0.0, 1.5), [g => 1], guesses = [λ => 0, y => 1],
116-
initialization_eqs = [y ~ 1])
116+
initialization_eqs = [y ~ 0])
117117
sol = solve(prob, Rodas5P())
118118
plot(sol, idxs = (x, y))
119119
```

ext/MTKInfiniteOptExt.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
module MTKInfiniteOptExt
2+
import ModelingToolkit
3+
import SymbolicUtils
4+
import NaNMath
5+
import InfiniteOpt
6+
import InfiniteOpt: JuMP, GeneralVariableRef
7+
8+
# This file contains method definitions to make it possible to trace through functions generated by MTK using JuMP variables
9+
10+
for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]
11+
f = nameof(ff)
12+
# These need to be defined so that JuMP can trace through functions built by Symbolics
13+
@eval NaNMath.$f(x::GeneralVariableRef) = Base.$f(x)
14+
end
15+
16+
# JuMP variables and Symbolics variables never compare equal. When tracing through dynamics, a function argument can be either a JuMP variable or A Symbolics variable, it can never be both.
17+
function Base.isequal(::SymbolicUtils.Symbolic,
18+
::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr})
19+
false
20+
end
21+
function Base.isequal(
22+
::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr},
23+
::SymbolicUtils.Symbolic)
24+
false
25+
end
26+
end

src/structural_transformation/symbolics_tearing.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ function tearing_sub(expr, dict, s)
8888
s ? simplify(expr) : expr
8989
end
9090

91+
function tearing_substitute_expr(sys::AbstractSystem, expr; simplify = false)
92+
empty_substitutions(sys) && return expr
93+
substitutions = get_substitutions(sys)
94+
@unpack subs = substitutions
95+
solved = Dict(eq.lhs => eq.rhs for eq in subs)
96+
return tearing_sub(expr, solved, simplify)
97+
end
98+
9199
"""
92100
$(TYPEDSIGNATURES)
93101

src/structural_transformation/utils.jl

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,41 @@ end
5858
###
5959
### Structural check
6060
###
61-
function check_consistency(state::TransformationState, orig_inputs)
61+
62+
"""
63+
$(TYPEDSIGNATURES)
64+
65+
Check if the `state` represents a singular system, and return the unmatched variables.
66+
"""
67+
function singular_check(state::TransformationState)
68+
@unpack graph, var_to_diff = state.structure
69+
fullvars = get_fullvars(state)
70+
# This is defined to check if Pantelides algorithm terminates. For more
71+
# details, check the equation (15) of the original paper.
72+
extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist;
73+
map(collect, edges(var_to_diff))])
74+
extended_var_eq_matching = maximal_matching(extended_graph)
75+
76+
nvars = ndsts(graph)
77+
unassigned_var = []
78+
for (vj, eq) in enumerate(extended_var_eq_matching)
79+
vj > nvars && break
80+
if eq === unassigned && !isempty(𝑑neighbors(graph, vj))
81+
push!(unassigned_var, fullvars[vj])
82+
end
83+
end
84+
return unassigned_var
85+
end
86+
87+
"""
88+
$(TYPEDSIGNATURES)
89+
90+
Check the consistency of `state`, given the inputs `orig_inputs`. If `nothrow == false`,
91+
throws an error if the system is under-/over-determined or singular. In this case, if the
92+
function returns it will return `true`. If `nothrow == true`, it will return `false`
93+
instead of throwing an error. The singular case will print a warning.
94+
"""
95+
function check_consistency(state::TransformationState, orig_inputs; nothrow = false)
6296
fullvars = get_fullvars(state)
6397
neqs = n_concrete_eqs(state)
6498
@unpack graph, var_to_diff = state.structure
@@ -72,6 +106,7 @@ function check_consistency(state::TransformationState, orig_inputs)
72106
is_balanced = n_highest_vars == neqs
73107

74108
if neqs > 0 && !is_balanced
109+
nothrow && return false
75110
varwhitelist = var_to_diff .== nothing
76111
var_eq_matching = maximal_matching(graph, eq -> true, v -> varwhitelist[v]) # not assigned
77112
# Just use `error_reporting` to do conditional
@@ -85,22 +120,12 @@ function check_consistency(state::TransformationState, orig_inputs)
85120
error_reporting(state, bad_idxs, n_highest_vars, iseqs, orig_inputs)
86121
end
87122

88-
# This is defined to check if Pantelides algorithm terminates. For more
89-
# details, check the equation (15) of the original paper.
90-
extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist;
91-
map(collect, edges(var_to_diff))])
92-
extended_var_eq_matching = maximal_matching(extended_graph)
93-
94-
nvars = ndsts(graph)
95-
unassigned_var = []
96-
for (vj, eq) in enumerate(extended_var_eq_matching)
97-
vj > nvars && break
98-
if eq === unassigned && !isempty(𝑑neighbors(graph, vj))
99-
push!(unassigned_var, fullvars[vj])
100-
end
101-
end
123+
unassigned_var = singular_check(state)
102124

103125
if !isempty(unassigned_var) || !is_balanced
126+
if nothrow
127+
return false
128+
end
104129
io = IOBuffer()
105130
Base.print_array(io, unassigned_var)
106131
unassigned_var_str = String(take!(io))
@@ -110,7 +135,7 @@ function check_consistency(state::TransformationState, orig_inputs)
110135
throw(InvalidSystemException(errmsg))
111136
end
112137

113-
return nothing
138+
return true
114139
end
115140

116141
###

src/systems/abstractsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3335,7 +3335,8 @@ function parse_variable(sys::AbstractSystem, str::AbstractString)
33353335
# I'd write a regex to validate `str`, but https://xkcd.com/1171/
33363336
str = strip(str)
33373337
derivative_level = 0
3338-
while ((cond1 = startswith(str, "D(")) || startswith(str, "Differential(")) && endswith(str, ")")
3338+
while ((cond1 = startswith(str, "D(")) || startswith(str, "Differential(")) &&
3339+
endswith(str, ")")
33393340
if cond1
33403341
derivative_level += 1
33413342
str = _string_view_inner(str, 2, 1)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
590590
checkbounds = false,
591591
initializeprob = nothing,
592592
initializeprobmap = nothing,
593+
initializeprobpmap = nothing,
594+
update_initializeprob! = nothing,
593595
kwargs...) where {iip}
594596
if !iscomplete(sys)
595597
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
@@ -643,7 +645,9 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
643645
jac_prototype = jac_prototype,
644646
observed = observedfun,
645647
initializeprob = initializeprob,
646-
initializeprobmap = initializeprobmap)
648+
initializeprobmap = initializeprobmap,
649+
initializeprobpmap = initializeprobpmap,
650+
update_initializeprob! = update_initializeprob!)
647651
end
648652

649653
function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
@@ -1387,7 +1391,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
13871391
check_length = true,
13881392
warn_initialize_determined = true,
13891393
initialization_eqs = [],
1390-
fully_determined = false,
1394+
fully_determined = nothing,
13911395
check_units = true,
13921396
kwargs...) where {iip, specialize}
13931397
if !iscomplete(sys)
@@ -1405,6 +1409,19 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
14051409
sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined)
14061410
end
14071411

1412+
ts = get_tearing_state(isys)
1413+
if warn_initialize_determined &&
1414+
(unassigned_vars = StructuralTransformations.singular_check(ts); !isempty(unassigned_vars))
1415+
errmsg = """
1416+
The initialization system is structurally singular. Guess values may \
1417+
significantly affect the initial values of the ODE. The problematic variables \
1418+
are $unassigned_vars.
1419+
1420+
Note that the identification of problematic variables is a best-effort heuristic.
1421+
"""
1422+
@warn errmsg
1423+
end
1424+
14081425
uninit = setdiff(unknowns(sys), [unknowns(isys); getfield.(observed(isys), :lhs)])
14091426

14101427
# TODO: throw on uninitialized arrays
@@ -1448,6 +1465,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
14481465
u0T = promote_type(u0T, typeof(fullmap[eq.lhs]))
14491466
end
14501467
if u0T != Union{}
1468+
u0T = eltype(u0T)
14511469
u0map = Dict(k => if symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)
14521470
v isa AbstractArray ? u0T.(v) : u0T(v)
14531471
else

src/systems/index_cache.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,12 @@ const TunableIndexMap = Dict{BasicSymbolic,
4040
Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}}
4141
const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}}
4242

43+
const SymbolicParam = Union{BasicSymbolic, CallWithMetadata}
44+
4345
struct IndexCache
4446
unknown_idx::UnknownIndexMap
4547
# sym => (bufferidx, idx_in_buffer)
46-
discrete_idx::Dict{BasicSymbolic, DiscreteIndex}
48+
discrete_idx::Dict{SymbolicParam, DiscreteIndex}
4749
# sym => (clockidx, idx_in_clockbuffer)
4850
callback_to_clocks::Dict{Any, Vector{Int}}
4951
tunable_idx::TunableIndexMap
@@ -56,13 +58,13 @@ struct IndexCache
5658
tunable_buffer_size::BufferTemplate
5759
constant_buffer_sizes::Vector{BufferTemplate}
5860
nonnumeric_buffer_sizes::Vector{BufferTemplate}
59-
symbol_to_variable::Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}
61+
symbol_to_variable::Dict{Symbol, SymbolicParam}
6062
end
6163

6264
function IndexCache(sys::AbstractSystem)
6365
unks = solved_unknowns(sys)
6466
unk_idxs = UnknownIndexMap()
65-
symbol_to_variable = Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}()
67+
symbol_to_variable = Dict{Symbol, SymbolicParam}()
6668

6769
let idx = 1
6870
for sym in unks
@@ -95,18 +97,18 @@ function IndexCache(sys::AbstractSystem)
9597

9698
tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
9799
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
98-
nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}()
100+
nonnumeric_buffers = Dict{Any, Set{SymbolicParam}}()
99101

100102
function insert_by_type!(buffers::Dict{Any, S}, sym, ctype) where {S}
101103
sym = unwrap(sym)
102104
buf = get!(buffers, ctype, S())
103105
push!(buf, sym)
104106
end
105107

106-
disc_param_callbacks = Dict{BasicSymbolic, Set{Int}}()
108+
disc_param_callbacks = Dict{SymbolicParam, Set{Int}}()
107109
events = vcat(continuous_events(sys), discrete_events(sys))
108110
for (i, event) in enumerate(events)
109-
discs = Set{BasicSymbolic}()
111+
discs = Set{SymbolicParam}()
110112
affs = affects(event)
111113
if !(affs isa AbstractArray)
112114
affs = [affs]
@@ -130,26 +132,32 @@ function IndexCache(sys::AbstractSystem)
130132
isequal(only(arguments(sym)), get_iv(sys))
131133
clocks = get!(() -> Set{Int}(), disc_param_callbacks, sym)
132134
push!(clocks, i)
133-
else
135+
elseif is_variable_floatingpoint(sym)
134136
insert_by_type!(constant_buffers, sym, symtype(sym))
137+
else
138+
stype = symtype(sym)
139+
if stype <: FnType
140+
stype = fntype_to_function_type(stype)
141+
end
142+
insert_by_type!(nonnumeric_buffers, sym, stype)
135143
end
136144
end
137145
end
138146
clock_partitions = unique(collect(values(disc_param_callbacks)))
139147
disc_symtypes = unique(symtype.(keys(disc_param_callbacks)))
140148
disc_symtype_idx = Dict(disc_symtypes .=> eachindex(disc_symtypes))
141-
disc_syms_by_symtype = [BasicSymbolic[] for _ in disc_symtypes]
149+
disc_syms_by_symtype = [SymbolicParam[] for _ in disc_symtypes]
142150
for sym in keys(disc_param_callbacks)
143151
push!(disc_syms_by_symtype[disc_symtype_idx[symtype(sym)]], sym)
144152
end
145-
disc_syms_by_symtype_by_partition = [Vector{BasicSymbolic}[] for _ in disc_symtypes]
153+
disc_syms_by_symtype_by_partition = [Vector{SymbolicParam}[] for _ in disc_symtypes]
146154
for (i, buffer) in enumerate(disc_syms_by_symtype)
147155
for partition in clock_partitions
148156
push!(disc_syms_by_symtype_by_partition[i],
149157
[sym for sym in buffer if disc_param_callbacks[sym] == partition])
150158
end
151159
end
152-
disc_idxs = Dict{BasicSymbolic, DiscreteIndex}()
160+
disc_idxs = Dict{SymbolicParam, DiscreteIndex}()
153161
callback_to_clocks = Dict{
154162
Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}, Set{Int}}()
155163
for (typei, disc_syms_by_partition) in enumerate(disc_syms_by_symtype_by_partition)
@@ -191,6 +199,7 @@ function IndexCache(sys::AbstractSystem)
191199
end
192200
haskey(disc_idxs, p) && continue
193201
haskey(constant_buffers, ctype) && p in constant_buffers[ctype] && continue
202+
haskey(nonnumeric_buffers, ctype) && p in nonnumeric_buffers[ctype] && continue
194203
insert_by_type!(
195204
if ctype <: Real || ctype <: AbstractArray{<:Real}
196205
if istunable(p, true) && Symbolics.shape(p) != Symbolics.Unknown() &&

src/systems/model_parsing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -923,7 +923,7 @@ function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types)
923923
end
924924

925925
push!(varexpr.args, metadata_expr)
926-
return vv isa Num ? name : :($name...), varexpr
926+
return symbolic_type(vv) == ScalarSymbolic() ? name : :($name...), varexpr
927927
else
928928
return vv
929929
end

0 commit comments

Comments
 (0)