Skip to content

Commit 1e5fce8

Browse files
committed
Merge remote-tracking branch 'origin' into implicit_discrete_system
Update to master
2 parents d40fd6e + 157966e commit 1e5fce8

25 files changed

+771
-96
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Yingbo Ma <[email protected]>", "Chris Rackauckas <[email protected]> and contributors"]
4-
version = "9.64.3"
4+
version = "9.65.0"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -136,7 +136,7 @@ Reexport = "0.2, 1"
136136
RuntimeGeneratedFunctions = "0.5.9"
137137
SCCNonlinearSolve = "1.0.0"
138138
SciMLBase = "2.75"
139-
SciMLStructures = "1.0"
139+
SciMLStructures = "1.7"
140140
Serialization = "1"
141141
Setfield = "0.7, 0.8, 1"
142142
SimpleNonlinearSolve = "0.1.0, 1, 2"

docs/src/tutorials/initialization.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,3 +536,14 @@ sol[α * x - β * x * y]
536536
```@example init
537537
plot(sol)
538538
```
539+
540+
## Summary of Initialization API
541+
542+
```@docs; canonical=false
543+
Initial
544+
isinitial
545+
generate_initializesystem
546+
initialization_equations
547+
guesses
548+
defaults
549+
```

ext/MTKChainRulesCoreExt.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,11 @@ function ChainRulesCore.rrule(
7979
end
8080
newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals)
8181
tunable_idxs = reduce(
82-
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable))
82+
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable);
83+
init = Union{Int, AbstractVector{Int}}[])
84+
initials_idxs = reduce(
85+
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Initials);
86+
init = Union{Int, AbstractVector{Int}}[])
8387
disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete)
8488
const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant)
8589
nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric)
@@ -91,10 +95,12 @@ function ChainRulesCore.rrule(
9195
indp′ = NoTangent()
9296

9397
tunable = selected_tangents(buf′.tunable, tunable_idxs)
98+
initials = selected_tangents(buf′.initials, initials_idxs)
9499
discrete = selected_tangents(buf′.discrete, disc_idxs)
95100
constant = selected_tangents(buf′.constant, const_idxs)
96101
nonnumeric = selected_tangents(buf′.nonnumeric, nn_idxs)
97-
oldbuf′ = Tangent{typeof(oldbuf)}(; tunable, discrete, constant, nonnumeric)
102+
oldbuf′ = Tangent{typeof(oldbuf)}(;
103+
tunable, initials, discrete, constant, nonnumeric)
98104
idxs′ = NoTangent()
99105
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
100106
return f′, indp′, oldbuf′, idxs′, vals′

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ export toexpr, get_variables
290290
export simplify, substitute
291291
export build_function
292292
export modelingtoolkitize
293-
export generate_initializesystem, Initial
293+
export generate_initializesystem, Initial, isinitial
294294

295295
export alg_equations, diff_equations, has_alg_equations, has_diff_equations
296296
export get_alg_eqs, get_diff_eqs, has_alg_eqs, has_diff_eqs

src/inputoutput.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
250250
f = build_function_wrapper(sys, rhss, args...; p_start = 3 + implicit_dae,
251251
p_end = length(p) + 2 + implicit_dae)
252252
f = eval_or_rgf.(f; eval_expression, eval_module)
253+
f = GeneratedFunctionWrapper{(
254+
3 + implicit_dae, length(args) - length(p) + 1, is_split(sys))}(f...)
255+
f = f, f
253256
ps = setdiff(parameters(sys), inputs, disturbance_inputs)
254257
(; f, dvs, ps, io_sys = sys)
255258
end

src/systems/abstractsystem.jl

Lines changed: 132 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,22 @@ function add_initialization_parameters(sys::AbstractSystem)
717717
return sys
718718
end
719719

720+
"""
721+
Returns true if the parameter `p` is of the form `Initial(x)`.
722+
"""
723+
function isinitial(p)
724+
p = unwrap(p)
725+
if iscall(p)
726+
operation(p) isa Initial && return true
727+
if operation(p) === getindex
728+
operation(arguments(p)[1]) isa Initial && return true
729+
end
730+
else
731+
return false
732+
end
733+
return false
734+
end
735+
720736
"""
721737
$(TYPEDSIGNATURES)
722738
@@ -757,38 +773,21 @@ function complete(sys::AbstractSystem; split = true, flatten = true)
757773
if !isempty(all_ps)
758774
# reorder parameters by portions
759775
ps_split = reorder_parameters(sys, all_ps)
776+
# if there are tunables, they will all be in `ps_split[1]`
777+
# and the arrays will have been scalarized
778+
ordered_ps = eltype(all_ps)[]
760779
# if there are no tunables, vcat them
761-
if isempty(get_index_cache(sys).tunable_idx)
762-
ordered_ps = reduce(vcat, ps_split)
763-
else
764-
# if there are tunables, they will all be in `ps_split[1]`
765-
# and the arrays will have been scalarized
766-
ordered_ps = eltype(all_ps)[]
767-
i = 1
768-
# go through all the tunables
769-
while i <= length(ps_split[1])
770-
sym = ps_split[1][i]
771-
# if the sym is not a scalarized array symbolic OR it was already scalarized,
772-
# just push it as-is
773-
if !iscall(sym) || operation(sym) != getindex ||
774-
any(isequal(sym), all_ps)
775-
push!(ordered_ps, sym)
776-
i += 1
777-
continue
778-
end
779-
# the next `length(sym)` symbols should be scalarized versions of the same
780-
# array symbolic
781-
if !allequal(first(arguments(x))
782-
for x in view(ps_split[1], i:(i + length(sym) - 1)))
783-
error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.")
784-
end
785-
arrsym = first(arguments(sym))
786-
push!(ordered_ps, arrsym)
787-
i += length(arrsym)
788-
end
789-
ordered_ps = vcat(
790-
ordered_ps, reduce(vcat, ps_split[2:end]; init = eltype(ordered_ps)[]))
780+
if !isempty(get_index_cache(sys).tunable_idx)
781+
unflatten_parameters!(ordered_ps, ps_split[1], all_ps)
782+
ps_split = Base.tail(ps_split)
783+
end
784+
# unflatten initial parameters
785+
if !isempty(get_index_cache(sys).initials_idx)
786+
unflatten_parameters!(ordered_ps, ps_split[1], all_ps)
787+
ps_split = Base.tail(ps_split)
791788
end
789+
ordered_ps = vcat(
790+
ordered_ps, reduce(vcat, ps_split; init = eltype(ordered_ps)[]))
792791
@set! sys.ps = ordered_ps
793792
end
794793
elseif has_index_cache(sys)
@@ -800,6 +799,39 @@ function complete(sys::AbstractSystem; split = true, flatten = true)
800799
isdefined(sys, :complete) ? (@set! sys.complete = true) : sys
801800
end
802801

802+
"""
803+
$(TYPEDSIGNATURES)
804+
805+
Given a flattened array of parameters `params` and a collection of all (unscalarized)
806+
parameters in the system `all_ps`, unscalarize the elements in `params` and append
807+
to `buffer` in the same order as they are present in `params`. Effectively, if
808+
`params = [p[1], p[2], p[3], q]` then this is equivalent to `push!(buffer, p, q)`.
809+
"""
810+
function unflatten_parameters!(buffer, params, all_ps)
811+
i = 1
812+
# go through all the tunables
813+
while i <= length(params)
814+
sym = params[i]
815+
# if the sym is not a scalarized array symbolic OR it was already scalarized,
816+
# just push it as-is
817+
if !iscall(sym) || operation(sym) != getindex ||
818+
any(isequal(sym), all_ps)
819+
push!(buffer, sym)
820+
i += 1
821+
continue
822+
end
823+
# the next `length(sym)` symbols should be scalarized versions of the same
824+
# array symbolic
825+
if !allequal(first(arguments(x))
826+
for x in view(params, i:(i + length(sym) - 1)))
827+
error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.")
828+
end
829+
arrsym = first(arguments(sym))
830+
push!(buffer, arrsym)
831+
i += length(arrsym)
832+
end
833+
end
834+
803835
for prop in [:eqs
804836
:tag
805837
:noiseeqs
@@ -846,6 +878,7 @@ for prop in [:eqs
846878
:assertions
847879
:solved_unknowns
848880
:split_idxs
881+
:ignored_connections
849882
:parent
850883
:is_dde
851884
:tstops
@@ -1362,6 +1395,75 @@ function assertions(sys::AbstractSystem)
13621395
return merge(asserts, namespaced_asserts)
13631396
end
13641397

1398+
const HierarchyVariableT = Vector{Union{BasicSymbolic, Symbol}}
1399+
const HierarchySystemT = Vector{Union{AbstractSystem, Symbol}}
1400+
"""
1401+
The type returned from `as_hierarchy`.
1402+
"""
1403+
const HierarchyT = Union{HierarchyVariableT, HierarchySystemT}
1404+
1405+
"""
1406+
$(TYPEDSIGNATURES)
1407+
1408+
The inverse operation of `as_hierarchy`.
1409+
"""
1410+
function from_hierarchy(hierarchy::HierarchyT)
1411+
namefn = hierarchy[1] isa AbstractSystem ? nameof : getname
1412+
foldl(@view hierarchy[2:end]; init = hierarchy[1]) do sys, name
1413+
rename(sys, Symbol(name, NAMESPACE_SEPARATOR, namefn(sys)))
1414+
end
1415+
end
1416+
1417+
"""
1418+
$(TYPEDSIGNATURES)
1419+
1420+
Represent a namespaced system (or variable) `sys` as a hierarchy. Return a vector, where
1421+
the first element is the unnamespaced system (variable) and subsequent elements are
1422+
`Symbol`s representing the parents of the unnamespaced system (variable) in order from
1423+
inner to outer.
1424+
"""
1425+
function as_hierarchy(sys::Union{AbstractSystem, BasicSymbolic})::HierarchyT
1426+
namefn = sys isa AbstractSystem ? nameof : getname
1427+
# get the hierarchy
1428+
hierarchy = namespace_hierarchy(namefn(sys))
1429+
# rename the system with unnamespaced name
1430+
newsys = rename(sys, hierarchy[end])
1431+
# and remove it from the list
1432+
pop!(hierarchy)
1433+
# reverse it to go from inner to outer
1434+
reverse!(hierarchy)
1435+
# concatenate
1436+
T = sys isa AbstractSystem ? AbstractSystem : BasicSymbolic
1437+
return Union{Symbol, T}[newsys; hierarchy]
1438+
end
1439+
1440+
"""
1441+
$(TYPEDSIGNATURES)
1442+
1443+
Get the connections to ignore for `sys` and its subsystems. The returned value is a
1444+
`Tuple` similar in structure to the `ignored_connections` field. Each system (variable)
1445+
in the first (second) element of the tuple is also passed through `as_hierarchy`.
1446+
"""
1447+
function ignored_connections(sys::AbstractSystem)
1448+
has_ignored_connections(sys) || return (HierarchySystemT[], HierarchyVariableT[])
1449+
1450+
ics = get_ignored_connections(sys)
1451+
if ics === nothing
1452+
ics = (HierarchySystemT[], HierarchyVariableT[])
1453+
end
1454+
# turn into hierarchies
1455+
ics = (map(as_hierarchy, ics[1]), map(as_hierarchy, ics[2]))
1456+
systems = get_systems(sys)
1457+
# for each subsystem, get its ignored connections, add the name of the subsystem
1458+
# to the hierarchy and concatenate corresponding buffers of the result
1459+
result = mapreduce(Broadcast.BroadcastFunction(vcat), systems; init = ics) do subsys
1460+
sub_ics = ignored_connections(subsys)
1461+
(map(Base.Fix2(push!, nameof(subsys)), sub_ics[1]),
1462+
map(Base.Fix2(push!, nameof(subsys)), sub_ics[2]))
1463+
end
1464+
return (Vector{HierarchySystemT}(result[1]), Vector{HierarchyVariableT}(result[2]))
1465+
end
1466+
13651467
"""
13661468
$(TYPEDSIGNATURES)
13671469

src/systems/analysis_points.jl

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,27 @@ function get_analysis_variable(var, name, iv; perturb = true)
412412
return pvar, default
413413
end
414414

415+
function with_analysis_point_ignored(sys::AbstractSystem, ap::AnalysisPoint)
416+
has_ignored_connections(sys) || return sys
417+
ignored = get_ignored_connections(sys)
418+
if ignored === nothing
419+
ignored = (ODESystem[], BasicSymbolic[])
420+
else
421+
ignored = copy.(ignored)
422+
end
423+
if ap.outputs === nothing
424+
error("Empty analysis point")
425+
end
426+
for x in ap.outputs
427+
if x isa ODESystem
428+
push!(ignored[1], x)
429+
else
430+
push!(ignored[2], unwrap(x))
431+
end
432+
end
433+
return @set sys.ignored_connections = ignored
434+
end
435+
415436
#### PRIMITIVE TRANSFORMATIONS
416437

417438
const DOC_WILL_REMOVE_AP = """
@@ -469,7 +490,9 @@ function apply_transformation(tf::Break, sys::AbstractSystem)
469490
ap = breaksys_eqs[ap_idx].rhs
470491
deleteat!(breaksys_eqs, ap_idx)
471492

472-
tf.add_input || return sys, ()
493+
breaksys = with_analysis_point_ignored(breaksys, ap)
494+
495+
tf.add_input || return breaksys, ()
473496

474497
ap_ivar = ap_var(ap.input)
475498
new_var, new_def = get_analysis_variable(ap_ivar, nameof(ap), get_iv(sys))
@@ -511,7 +534,7 @@ function apply_transformation(tf::GetInput, sys::AbstractSystem)
511534
ap_idx === nothing &&
512535
error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).")
513536
# get the anlysis point
514-
ap_sys_eqs = copy(get_eqs(ap_sys))
537+
ap_sys_eqs = get_eqs(ap_sys)
515538
ap = ap_sys_eqs[ap_idx].rhs
516539

517540
# input variable
@@ -570,6 +593,7 @@ function apply_transformation(tf::PerturbOutput, sys::AbstractSystem)
570593
ap = ap_sys_eqs[ap_idx].rhs
571594
# remove analysis point
572595
deleteat!(ap_sys_eqs, ap_idx)
596+
ap_sys = with_analysis_point_ignored(ap_sys, ap)
573597

574598
# add equations involving new variable
575599
ap_ivar = ap_var(ap.input)
@@ -634,7 +658,7 @@ function apply_transformation(tf::AddVariable, sys::AbstractSystem)
634658
ap_idx = analysis_point_index(ap_sys, tf.ap)
635659
ap_idx === nothing &&
636660
error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).")
637-
ap_sys_eqs = copy(get_eqs(ap_sys))
661+
ap_sys_eqs = get_eqs(ap_sys)
638662
ap = ap_sys_eqs[ap_idx].rhs
639663

640664
# add equations involving new variable

src/systems/codegen_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ end
287287
# The user provided a single buffer/tuple for the parameter object, so wrap that
288288
# one in a tuple
289289
fargs = ntuple(Val(length(args))) do i
290-
i == paramidx ? :((args[$i],)) : :(args[$i])
290+
i == paramidx ? :((args[$i], nothing)) : :(args[$i])
291291
end
292292
return :($f($(fargs...)))
293293
end

0 commit comments

Comments
 (0)