Skip to content

Commit 9b0222f

Browse files
Merge branch 'master' into iss3707
2 parents 8d8d057 + 68f5e73 commit 9b0222f

21 files changed

+430
-43
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Yingbo Ma <[email protected]>", "Chris Rackauckas <[email protected]> and contributors"]
4-
version = "10.19.0"
4+
version = "10.21.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/API/model_building.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ add_accumulations
229229
noise_to_brownians
230230
convert_system_indepvar
231231
subset_tunables
232+
respecialize
232233
```
233234

234235
## Hybrid systems

src/ModelingToolkit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,8 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc
269269
subset_tunables
270270
export liouville_transform, change_independent_variable, substitute_component,
271271
add_accumulations, noise_to_brownians, Girsanov_transform, change_of_variables,
272-
fractional_to_ordinary, linear_fractional_to_ordinary
272+
fractional_to_ordinary, linear_fractional_to_ordinary,
273+
export respecialize
273274
export PDESystem
274275
export Differential, expand_derivatives, @derivatives
275276
export Equation, ConstrainedEquation

src/inputoutput.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Symbolics: get_variables
55
Return all variables that mare marked as inputs. See also [`unbound_inputs`](@ref)
66
See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref)
77
"""
8-
inputs(sys) = [filter(isinput, unknowns(sys)); filter(isinput, parameters(sys))]
8+
inputs(sys) = collect(get_inputs(sys))
99

1010
"""
1111
outputs(sys)
@@ -14,13 +14,7 @@ Return all variables that mare marked as outputs. See also [`unbound_outputs`](@
1414
See also [`bound_outputs`](@ref), [`unbound_outputs`](@ref)
1515
"""
1616
function outputs(sys)
17-
o = observed(sys)
18-
rhss = [eq.rhs for eq in o]
19-
lhss = [eq.lhs for eq in o]
20-
unique([filter(isoutput, unknowns(sys))
21-
filter(isoutput, parameters(sys))
22-
filter(x -> iscall(x) && isoutput(x), rhss) # observed can return equations with complicated expressions, we are only looking for single Terms
23-
filter(x -> iscall(x) && isoutput(x), lhss)])
17+
return collect(get_outputs(sys))
2418
end
2519

2620
"""
@@ -288,7 +282,12 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
288282
push!(new_fullvars, v)
289283
end
290284
end
291-
ninputs == 0 && return state
285+
if ninputs == 0
286+
@set! sys.inputs = OrderedSet{BasicSymbolic}()
287+
@set! sys.outputs = OrderedSet{BasicSymbolic}(filter(isoutput, fullvars))
288+
state.sys = sys
289+
return state
290+
end
292291

293292
nvars = ndsts(graph) - ninputs
294293
new_graph = BipartiteGraph(nsrcs(graph), nvars, Val(false))
@@ -318,6 +317,8 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
318317
ps = parameters(sys)
319318

320319
@set! sys.ps = [ps; new_parameters]
320+
@set! sys.inputs = OrderedSet{BasicSymbolic}(new_parameters)
321+
@set! sys.outputs = OrderedSet{BasicSymbolic}(filter(isoutput, fullvars))
321322
@set! state.sys = sys
322323
@set! state.fullvars = Vector{BasicSymbolic}(new_fullvars)
323324
@set! state.structure = structure

src/linearization.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -572,8 +572,7 @@ struct IONotFoundError <: Exception
572572
end
573573

574574
function Base.showerror(io::IO, err::IONotFoundError)
575-
println(io,
576-
"The following $(err.variant) provided to `mtkcompile` were not found in the system:")
575+
println(io, "The following $(err.variant) provided to `mtkcompile` were not found in the system:")
577576
maybe_namespace_issue = false
578577
for var in err.not_found
579578
println(io, " ", var)

src/systems/abstractsystem.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,8 @@ const SYS_PROPS = [:eqs
784784
:parent
785785
:is_dde
786786
:tstops
787+
:inputs
788+
:outputs
787789
:index_cache
788790
:isscheduled
789791
:costs
@@ -1820,6 +1822,17 @@ function push_vars!(stmt, name, typ, vars)
18201822
ex = nameof(s)
18211823
end
18221824
push!(vars_expr.args, ex)
1825+
1826+
meta_kvps = Expr[]
1827+
if isinput(s)
1828+
push!(meta_kvps, :(input = true))
1829+
end
1830+
if isoutput(s)
1831+
push!(meta_kvps, :(output = true))
1832+
end
1833+
if !isempty(meta_kvps)
1834+
push!(vars_expr.args, Expr(:vect, meta_kvps...))
1835+
end
18231836
end
18241837
push!(stmt, :($name = $collect($vars_expr)))
18251838
return

src/systems/callbacks.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ function SymbolicAffect(affect::SymbolicAffect; kwargs...)
2525
end
2626
SymbolicAffect(affect; kwargs...) = make_affect(affect; kwargs...)
2727

28+
function Symbolics.fast_substitute(aff::SymbolicAffect, rules)
29+
substituter = Base.Fix2(fast_substitute, rules)
30+
SymbolicAffect(map(substituter, aff.affect), map(substituter, aff.alg_eqs),
31+
map(substituter, aff.discrete_parameters))
32+
end
33+
2834
struct AffectSystem
2935
"""The internal implicit discrete system whose equations are solved to obtain values after the affect."""
3036
system::AbstractSystem
@@ -36,6 +42,19 @@ struct AffectSystem
3642
discretes::Vector
3743
end
3844

45+
function Symbolics.fast_substitute(aff::AffectSystem, rules)
46+
substituter = Base.Fix2(fast_substitute, rules)
47+
sys = aff.system
48+
@set! sys.eqs = map(substituter, get_eqs(sys))
49+
@set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys))
50+
@set! sys.defaults = Dict([k => substituter(v) for (k, v) in defaults(sys)])
51+
@set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)])
52+
@set! sys.unknowns = map(substituter, get_unknowns(sys))
53+
@set! sys.ps = map(substituter, get_ps(sys))
54+
AffectSystem(sys, map(substituter, aff.unknowns),
55+
map(substituter, aff.parameters), map(substituter, aff.discretes))
56+
end
57+
3958
function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], kwargs...)
4059
AffectSystem(spec.affect; alg_eqs = vcat(spec.alg_eqs, alg_eqs), iv,
4160
discrete_parameters = spec.discrete_parameters, kwargs...)
@@ -855,7 +874,7 @@ end
855874
function default_operating_point(affsys::AffectSystem)
856875
sys = system(affsys)
857876

858-
op = Dict(unknowns(sys) .=> 0.0)
877+
op = AnyDict(unknowns(sys) .=> 0.0)
859878
for p in parameters(sys)
860879
T = symtype(p)
861880
if T <: Number

src/systems/connectors.jl

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,7 @@ function generate_connection_equations_and_stream_connections(
759759
var = variable_from_vertex(sys, cvert)::BasicSymbolic
760760
vtype = cvert.type
761761
if vtype <: Union{InputVar, OutputVar}
762+
length(cset) > 1 || continue
762763
inner_output = nothing
763764
outer_input = nothing
764765
for cvert in cset
@@ -780,11 +781,11 @@ function generate_connection_equations_and_stream_connections(
780781
inner_output = cvert
781782
end
782783
end
783-
root, rest = Iterators.peel(cset)
784-
root_var = variable_from_vertex(sys, root)
785-
for cvert in rest
786-
var = variable_from_vertex(sys, cvert)
787-
push!(eqs, root_var ~ var)
784+
root_vert = something(inner_output, outer_input)
785+
root_var = variable_from_vertex(sys, root_vert)
786+
for cvert in cset
787+
isequal(cvert, root_vert) && continue
788+
push!(eqs, variable_from_vertex(sys, cvert) ~ root_var)
788789
end
789790
elseif vtype === Stream
790791
push!(stream_connections, cset)
@@ -807,10 +808,37 @@ function generate_connection_equations_and_stream_connections(
807808
push!(eqs, 0 ~ rhs)
808809
end
809810
else # Equality
810-
base = variable_from_vertex(sys, cset[1])
811-
for i in 2:length(cset)
812-
v = variable_from_vertex(sys, cset[i])
813-
push!(eqs, base ~ v)
811+
vars = map(Base.Fix1(variable_from_vertex, sys), cset)
812+
outer_input = inner_output = nothing
813+
all_io = true
814+
# attempt to interpret the equality as a causal connectionset if
815+
# possible
816+
for (cvert, vert) in zip(cset, vars)
817+
is_i = isinput(vert)
818+
is_o = isoutput(vert)
819+
all_io &= is_i || is_o
820+
all_io || break
821+
if cvert.isouter && is_i && outer_input === nothing
822+
outer_input = cvert
823+
elseif !cvert.isouter && is_o && inner_output === nothing
824+
inner_output = cvert
825+
end
826+
end
827+
# this doesn't necessarily mean this is a well-structured causal connection,
828+
# but it is sufficient and we're generating equalities anyway.
829+
if all_io && xor(outer_input !== nothing, inner_output !== nothing)
830+
root_vert = something(inner_output, outer_input)
831+
root_var = variable_from_vertex(sys, root_vert)
832+
for (cvert, var) in zip(cset, vars)
833+
isequal(cvert, root_vert) && continue
834+
push!(eqs, var ~ root_var)
835+
end
836+
else
837+
base = variable_from_vertex(sys, cset[1])
838+
for i in 2:length(cset)
839+
v = vars[i]
840+
push!(eqs, base ~ v)
841+
end
814842
end
815843
end
816844
end

src/systems/diffeqs/basic_transformations.jl

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,3 +942,143 @@ function convert_system_indepvar(sys::System, t; name = nameof(sys))
942942
@set! sys.var_to_name = var_to_name
943943
return sys
944944
end
945+
946+
"""
947+
$(TYPEDSIGNATURES)
948+
949+
Shorthand for `respecialize(sys, []; all = true)`
950+
"""
951+
respecialize(sys::AbstractSystem) = respecialize(sys, []; all = true)
952+
953+
"""
954+
$(TYPEDSIGNATURES)
955+
956+
Specialize nonnumeric parameters in `sys` by changing their symtype to a concrete type.
957+
`mapping` is an iterable, where each element can be a parameter or a pair mapping a parameter
958+
to a value. If the element is a parameter, it must have a default. Each specified parameter
959+
is updated to have the symtype of the value associated with it (either in `mapping` or in
960+
the defaults). This operation can only be performed on nonnumeric, non-array parameters. The
961+
defaults of respecialized parameters are set to the associated values.
962+
963+
This operation can only be performed on `complete`d systems.
964+
965+
# Keyword arguments
966+
967+
- `all`: Specialize all nonnumeric parameters in the system. This will error if any such
968+
parameter does not have a default.
969+
"""
970+
function respecialize(sys::AbstractSystem, mapping; all = false)
971+
if !iscomplete(sys)
972+
error("""
973+
This operation can only be performed on completed systems. Use `complete(sys)` or
974+
`mtkcompile(sys)`.
975+
""")
976+
end
977+
if !is_split(sys)
978+
error("""
979+
This operation can only be performed on split systems. Use `complete(sys)` or
980+
`mtkcompile(sys)` with the `split = true` keyword argument.
981+
""")
982+
end
983+
984+
new_ps = copy(get_ps(sys))
985+
@set! sys.ps = new_ps
986+
987+
extras = []
988+
if all
989+
for x in filter(!is_variable_numeric, get_ps(sys))
990+
if any(y -> isequal(x, y) || y isa Pair && isequal(x, y[1]), mapping) ||
991+
symbolic_type(x) === ArraySymbolic() ||
992+
iscall(x) && operation(x) === getindex
993+
continue
994+
end
995+
push!(extras, x)
996+
end
997+
end
998+
ps_to_specialize = Iterators.flatten((extras, mapping))
999+
1000+
defs = copy(defaults(sys))
1001+
@set! sys.defaults = defs
1002+
final_defs = copy(defs)
1003+
evaluate_varmap!(final_defs, ps_to_specialize)
1004+
1005+
subrules = Dict()
1006+
1007+
for element in ps_to_specialize
1008+
if element isa Pair
1009+
k, v = element
1010+
else
1011+
k = element
1012+
v = get(final_defs, k, nothing)
1013+
@assert v !== nothing """
1014+
Parameter $k needs an associated value to be respecialized.
1015+
"""
1016+
@assert symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) """
1017+
Parameter $k needs an associated value to be respecialized. Found symbolic \
1018+
default $v.
1019+
"""
1020+
end
1021+
1022+
k = unwrap(k)
1023+
T = typeof(v)
1024+
1025+
@assert !is_variable_numeric(k) """
1026+
Numeric types cannot be respecialized - tried to respecialize $k.
1027+
"""
1028+
@assert symbolic_type(k) !== ArraySymbolic() """
1029+
Cannot respecialize array symbolics - tried to respecialize $k.
1030+
"""
1031+
@assert !iscall(k) || operation(k) !== getindex """
1032+
Cannot respecialized scalarized array variables - tried to respecialize $k.
1033+
"""
1034+
idx = findfirst(isequal(k), get_ps(sys))
1035+
@assert idx !== nothing """
1036+
Parameter $k does not exist in the system.
1037+
"""
1038+
1039+
if iscall(k)
1040+
op = operation(k)
1041+
args = arguments(k)
1042+
new_p = SymbolicUtils.term(op, args...; type = T)
1043+
else
1044+
new_p = SymbolicUtils.Sym{T}(getname(k))
1045+
end
1046+
1047+
get_ps(sys)[idx] = new_p
1048+
defaults(sys)[new_p] = v
1049+
subrules[unwrap(k)] = unwrap(new_p)
1050+
end
1051+
1052+
substituter = Base.Fix2(fast_substitute, subrules)
1053+
@set! sys.eqs = map(substituter, get_eqs(sys))
1054+
@set! sys.observed = map(substituter, get_observed(sys))
1055+
@set! sys.initialization_eqs = map(substituter, get_initialization_eqs(sys))
1056+
if get_noise_eqs(sys) !== nothing
1057+
@set! sys.noise_eqs = map(substituter, get_noise_eqs(sys))
1058+
end
1059+
@set! sys.assertions = Dict([substituter(k) => v for (k, v) in assertions(sys)])
1060+
@set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys))
1061+
@set! sys.defaults = Dict([substituter(k) => substituter(v) for (k, v) in defaults(sys)])
1062+
@set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)])
1063+
@set! sys.continuous_events = map(get_continuous_events(sys)) do cev
1064+
SymbolicContinuousCallback(
1065+
map(substituter, cev.conditions), substituter(cev.affect),
1066+
substituter(cev.affect_neg), substituter(cev.initialize),
1067+
substituter(cev.finalize), cev.rootfind,
1068+
cev.reinitializealg, cev.zero_crossing_id)
1069+
end
1070+
@set! sys.discrete_events = map(get_discrete_events(sys)) do dev
1071+
SymbolicDiscreteCallback(map(substituter, dev.conditions), substituter(dev.affect),
1072+
substituter(dev.initialize), substituter(dev.finalize), dev.reinitializealg)
1073+
end
1074+
if get_schedule(sys) !== nothing
1075+
sched = get_schedule(sys)
1076+
@set! sys.schedule = Schedule(
1077+
sched.var_sccs, AnyDict(k => substituter(v) for (k, v) in sched.dummy_sub))
1078+
end
1079+
@set! sys.constraints = map(substituter, get_constraints(sys))
1080+
@set! sys.tstops = map(substituter, get_tstops(sys))
1081+
@set! sys.costs = Vector{Union{Real, BasicSymbolic}}(map(substituter, get_costs(sys)))
1082+
sys = complete(sys; split = is_split(sys))
1083+
return sys
1084+
end

src/systems/imperative_affect.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ function ImperativeAffect(; f, kwargs...)
6767
ImperativeAffect(f; kwargs...)
6868
end
6969

70+
function Symbolics.fast_substitute(aff::ImperativeAffect, rules)
71+
substituter = Base.Fix2(fast_substitute, rules)
72+
ImperativeAffect(aff.f, map(substituter, aff.obs), aff.obs_syms,
73+
map(substituter, aff.modified), aff.mod_syms, aff.ctx, aff.skip_checks)
74+
end
75+
7076
function Base.show(io::IO, mfa::ImperativeAffect)
7177
obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ")
7278
mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ")

0 commit comments

Comments
 (0)