Skip to content

Commit a0d3ab8

Browse files
Merge branch 'SciML:master' into dg/metamacro
2 parents 1a46ff6 + 68f5e73 commit a0d3ab8

23 files changed

+440
-47
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Yingbo Ma <[email protected]>", "Chris Rackauckas <[email protected]> and contributors"]
4-
version = "10.19.0"
4+
version = "10.21.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/API/model_building.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ add_accumulations
227227
noise_to_brownians
228228
convert_system_indepvar
229229
subset_tunables
230+
respecialize
230231
```
231232

232233
## Hybrid systems

src/ModelingToolkit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc
268268
hasmisc, getmisc, state_priority,
269269
subset_tunables
270270
export liouville_transform, change_independent_variable, substitute_component,
271-
add_accumulations, noise_to_brownians, Girsanov_transform, change_of_variables
271+
add_accumulations, noise_to_brownians, Girsanov_transform, change_of_variables,
272+
respecialize
272273
export PDESystem
273274
export Differential, expand_derivatives, @derivatives
274275
export Equation, ConstrainedEquation

src/inputoutput.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Symbolics: get_variables
55
Return all variables that mare marked as inputs. See also [`unbound_inputs`](@ref)
66
See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref)
77
"""
8-
inputs(sys) = [filter(isinput, unknowns(sys)); filter(isinput, parameters(sys))]
8+
inputs(sys) = collect(get_inputs(sys))
99

1010
"""
1111
outputs(sys)
@@ -14,13 +14,7 @@ Return all variables that mare marked as outputs. See also [`unbound_outputs`](@
1414
See also [`bound_outputs`](@ref), [`unbound_outputs`](@ref)
1515
"""
1616
function outputs(sys)
17-
o = observed(sys)
18-
rhss = [eq.rhs for eq in o]
19-
lhss = [eq.lhs for eq in o]
20-
unique([filter(isoutput, unknowns(sys))
21-
filter(isoutput, parameters(sys))
22-
filter(x -> iscall(x) && isoutput(x), rhss) # observed can return equations with complicated expressions, we are only looking for single Terms
23-
filter(x -> iscall(x) && isoutput(x), lhss)])
17+
return collect(get_outputs(sys))
2418
end
2519

2620
"""
@@ -288,7 +282,12 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
288282
push!(new_fullvars, v)
289283
end
290284
end
291-
ninputs == 0 && return state
285+
if ninputs == 0
286+
@set! sys.inputs = OrderedSet{BasicSymbolic}()
287+
@set! sys.outputs = OrderedSet{BasicSymbolic}(filter(isoutput, fullvars))
288+
state.sys = sys
289+
return state
290+
end
292291

293292
nvars = ndsts(graph) - ninputs
294293
new_graph = BipartiteGraph(nsrcs(graph), nvars, Val(false))
@@ -318,6 +317,8 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
318317
ps = parameters(sys)
319318

320319
@set! sys.ps = [ps; new_parameters]
320+
@set! sys.inputs = OrderedSet{BasicSymbolic}(new_parameters)
321+
@set! sys.outputs = OrderedSet{BasicSymbolic}(filter(isoutput, fullvars))
321322
@set! state.sys = sys
322323
@set! state.fullvars = Vector{BasicSymbolic}(new_fullvars)
323324
@set! state.structure = structure

src/linearization.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -572,8 +572,7 @@ struct IONotFoundError <: Exception
572572
end
573573

574574
function Base.showerror(io::IO, err::IONotFoundError)
575-
println(io,
576-
"The following $(err.variant) provided to `mtkcompile` were not found in the system:")
575+
println(io, "The following $(err.variant) provided to `mtkcompile` were not found in the system:")
577576
maybe_namespace_issue = false
578577
for var in err.not_found
579578
println(io, " ", var)

src/problems/odeproblem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
t = nothing, eval_expression = false, eval_module = @__MODULE__, sparse = false,
44
steady_state = false, checkbounds = false, sparsity = false, analytic = nothing,
55
simplify = false, cse = true, initialization_data = nothing, expression = Val{false},
6-
check_compatibility = true, nlstep = false, nlstep_compile = true, kwargs...) where {iip, spec}
6+
check_compatibility = true, nlstep = false, nlstep_compile = true, nlstep_scc = false,
7+
kwargs...) where {iip, spec}
78
check_complete(sys, ODEFunction)
89
check_compatibility && check_compatible_system(ODEFunction, sys)
910

@@ -42,7 +43,7 @@
4243
_M = concrete_massmatrix(M; sparse, u0)
4344

4445
if nlstep
45-
ode_nlstep = generate_ODENLStepData(sys, u0, p, M, nlstep_compile)
46+
ode_nlstep = generate_ODENLStepData(sys, u0, p, M, nlstep_compile, nlstep_scc)
4647
else
4748
ode_nlstep = nothing
4849
end

src/systems/abstractsystem.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,8 @@ const SYS_PROPS = [:eqs
784784
:parent
785785
:is_dde
786786
:tstops
787+
:inputs
788+
:outputs
787789
:index_cache
788790
:isscheduled
789791
:costs
@@ -1820,6 +1822,17 @@ function push_vars!(stmt, name, typ, vars)
18201822
ex = nameof(s)
18211823
end
18221824
push!(vars_expr.args, ex)
1825+
1826+
meta_kvps = Expr[]
1827+
if isinput(s)
1828+
push!(meta_kvps, :(input = true))
1829+
end
1830+
if isoutput(s)
1831+
push!(meta_kvps, :(output = true))
1832+
end
1833+
if !isempty(meta_kvps)
1834+
push!(vars_expr.args, Expr(:vect, meta_kvps...))
1835+
end
18231836
end
18241837
push!(stmt, :($name = $collect($vars_expr)))
18251838
return

src/systems/callbacks.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ function SymbolicAffect(affect::SymbolicAffect; kwargs...)
2525
end
2626
SymbolicAffect(affect; kwargs...) = make_affect(affect; kwargs...)
2727

28+
function Symbolics.fast_substitute(aff::SymbolicAffect, rules)
29+
substituter = Base.Fix2(fast_substitute, rules)
30+
SymbolicAffect(map(substituter, aff.affect), map(substituter, aff.alg_eqs),
31+
map(substituter, aff.discrete_parameters))
32+
end
33+
2834
struct AffectSystem
2935
"""The internal implicit discrete system whose equations are solved to obtain values after the affect."""
3036
system::AbstractSystem
@@ -36,6 +42,19 @@ struct AffectSystem
3642
discretes::Vector
3743
end
3844

45+
function Symbolics.fast_substitute(aff::AffectSystem, rules)
46+
substituter = Base.Fix2(fast_substitute, rules)
47+
sys = aff.system
48+
@set! sys.eqs = map(substituter, get_eqs(sys))
49+
@set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys))
50+
@set! sys.defaults = Dict([k => substituter(v) for (k, v) in defaults(sys)])
51+
@set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)])
52+
@set! sys.unknowns = map(substituter, get_unknowns(sys))
53+
@set! sys.ps = map(substituter, get_ps(sys))
54+
AffectSystem(sys, map(substituter, aff.unknowns),
55+
map(substituter, aff.parameters), map(substituter, aff.discretes))
56+
end
57+
3958
function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], kwargs...)
4059
AffectSystem(spec.affect; alg_eqs = vcat(spec.alg_eqs, alg_eqs), iv,
4160
discrete_parameters = spec.discrete_parameters, kwargs...)
@@ -855,7 +874,7 @@ end
855874
function default_operating_point(affsys::AffectSystem)
856875
sys = system(affsys)
857876

858-
op = Dict(unknowns(sys) .=> 0.0)
877+
op = AnyDict(unknowns(sys) .=> 0.0)
859878
for p in parameters(sys)
860879
T = symtype(p)
861880
if T <: Number

src/systems/connectors.jl

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,7 @@ function generate_connection_equations_and_stream_connections(
759759
var = variable_from_vertex(sys, cvert)::BasicSymbolic
760760
vtype = cvert.type
761761
if vtype <: Union{InputVar, OutputVar}
762+
length(cset) > 1 || continue
762763
inner_output = nothing
763764
outer_input = nothing
764765
for cvert in cset
@@ -780,11 +781,11 @@ function generate_connection_equations_and_stream_connections(
780781
inner_output = cvert
781782
end
782783
end
783-
root, rest = Iterators.peel(cset)
784-
root_var = variable_from_vertex(sys, root)
785-
for cvert in rest
786-
var = variable_from_vertex(sys, cvert)
787-
push!(eqs, root_var ~ var)
784+
root_vert = something(inner_output, outer_input)
785+
root_var = variable_from_vertex(sys, root_vert)
786+
for cvert in cset
787+
isequal(cvert, root_vert) && continue
788+
push!(eqs, variable_from_vertex(sys, cvert) ~ root_var)
788789
end
789790
elseif vtype === Stream
790791
push!(stream_connections, cset)
@@ -807,10 +808,37 @@ function generate_connection_equations_and_stream_connections(
807808
push!(eqs, 0 ~ rhs)
808809
end
809810
else # Equality
810-
base = variable_from_vertex(sys, cset[1])
811-
for i in 2:length(cset)
812-
v = variable_from_vertex(sys, cset[i])
813-
push!(eqs, base ~ v)
811+
vars = map(Base.Fix1(variable_from_vertex, sys), cset)
812+
outer_input = inner_output = nothing
813+
all_io = true
814+
# attempt to interpret the equality as a causal connectionset if
815+
# possible
816+
for (cvert, vert) in zip(cset, vars)
817+
is_i = isinput(vert)
818+
is_o = isoutput(vert)
819+
all_io &= is_i || is_o
820+
all_io || break
821+
if cvert.isouter && is_i && outer_input === nothing
822+
outer_input = cvert
823+
elseif !cvert.isouter && is_o && inner_output === nothing
824+
inner_output = cvert
825+
end
826+
end
827+
# this doesn't necessarily mean this is a well-structured causal connection,
828+
# but it is sufficient and we're generating equalities anyway.
829+
if all_io && xor(outer_input !== nothing, inner_output !== nothing)
830+
root_vert = something(inner_output, outer_input)
831+
root_var = variable_from_vertex(sys, root_vert)
832+
for (cvert, var) in zip(cset, vars)
833+
isequal(cvert, root_vert) && continue
834+
push!(eqs, var ~ root_var)
835+
end
836+
else
837+
base = variable_from_vertex(sys, cset[1])
838+
for i in 2:length(cset)
839+
v = vars[i]
840+
push!(eqs, base ~ v)
841+
end
814842
end
815843
end
816844
end

src/systems/diffeqs/basic_transformations.jl

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,3 +706,143 @@ function convert_system_indepvar(sys::System, t; name = nameof(sys))
706706
@set! sys.var_to_name = var_to_name
707707
return sys
708708
end
709+
710+
"""
711+
$(TYPEDSIGNATURES)
712+
713+
Shorthand for `respecialize(sys, []; all = true)`
714+
"""
715+
respecialize(sys::AbstractSystem) = respecialize(sys, []; all = true)
716+
717+
"""
718+
$(TYPEDSIGNATURES)
719+
720+
Specialize nonnumeric parameters in `sys` by changing their symtype to a concrete type.
721+
`mapping` is an iterable, where each element can be a parameter or a pair mapping a parameter
722+
to a value. If the element is a parameter, it must have a default. Each specified parameter
723+
is updated to have the symtype of the value associated with it (either in `mapping` or in
724+
the defaults). This operation can only be performed on nonnumeric, non-array parameters. The
725+
defaults of respecialized parameters are set to the associated values.
726+
727+
This operation can only be performed on `complete`d systems.
728+
729+
# Keyword arguments
730+
731+
- `all`: Specialize all nonnumeric parameters in the system. This will error if any such
732+
parameter does not have a default.
733+
"""
734+
function respecialize(sys::AbstractSystem, mapping; all = false)
735+
if !iscomplete(sys)
736+
error("""
737+
This operation can only be performed on completed systems. Use `complete(sys)` or
738+
`mtkcompile(sys)`.
739+
""")
740+
end
741+
if !is_split(sys)
742+
error("""
743+
This operation can only be performed on split systems. Use `complete(sys)` or
744+
`mtkcompile(sys)` with the `split = true` keyword argument.
745+
""")
746+
end
747+
748+
new_ps = copy(get_ps(sys))
749+
@set! sys.ps = new_ps
750+
751+
extras = []
752+
if all
753+
for x in filter(!is_variable_numeric, get_ps(sys))
754+
if any(y -> isequal(x, y) || y isa Pair && isequal(x, y[1]), mapping) ||
755+
symbolic_type(x) === ArraySymbolic() ||
756+
iscall(x) && operation(x) === getindex
757+
continue
758+
end
759+
push!(extras, x)
760+
end
761+
end
762+
ps_to_specialize = Iterators.flatten((extras, mapping))
763+
764+
defs = copy(defaults(sys))
765+
@set! sys.defaults = defs
766+
final_defs = copy(defs)
767+
evaluate_varmap!(final_defs, ps_to_specialize)
768+
769+
subrules = Dict()
770+
771+
for element in ps_to_specialize
772+
if element isa Pair
773+
k, v = element
774+
else
775+
k = element
776+
v = get(final_defs, k, nothing)
777+
@assert v !== nothing """
778+
Parameter $k needs an associated value to be respecialized.
779+
"""
780+
@assert symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) """
781+
Parameter $k needs an associated value to be respecialized. Found symbolic \
782+
default $v.
783+
"""
784+
end
785+
786+
k = unwrap(k)
787+
T = typeof(v)
788+
789+
@assert !is_variable_numeric(k) """
790+
Numeric types cannot be respecialized - tried to respecialize $k.
791+
"""
792+
@assert symbolic_type(k) !== ArraySymbolic() """
793+
Cannot respecialize array symbolics - tried to respecialize $k.
794+
"""
795+
@assert !iscall(k) || operation(k) !== getindex """
796+
Cannot respecialized scalarized array variables - tried to respecialize $k.
797+
"""
798+
idx = findfirst(isequal(k), get_ps(sys))
799+
@assert idx !== nothing """
800+
Parameter $k does not exist in the system.
801+
"""
802+
803+
if iscall(k)
804+
op = operation(k)
805+
args = arguments(k)
806+
new_p = SymbolicUtils.term(op, args...; type = T)
807+
else
808+
new_p = SymbolicUtils.Sym{T}(getname(k))
809+
end
810+
811+
get_ps(sys)[idx] = new_p
812+
defaults(sys)[new_p] = v
813+
subrules[unwrap(k)] = unwrap(new_p)
814+
end
815+
816+
substituter = Base.Fix2(fast_substitute, subrules)
817+
@set! sys.eqs = map(substituter, get_eqs(sys))
818+
@set! sys.observed = map(substituter, get_observed(sys))
819+
@set! sys.initialization_eqs = map(substituter, get_initialization_eqs(sys))
820+
if get_noise_eqs(sys) !== nothing
821+
@set! sys.noise_eqs = map(substituter, get_noise_eqs(sys))
822+
end
823+
@set! sys.assertions = Dict([substituter(k) => v for (k, v) in assertions(sys)])
824+
@set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys))
825+
@set! sys.defaults = Dict([substituter(k) => substituter(v) for (k, v) in defaults(sys)])
826+
@set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)])
827+
@set! sys.continuous_events = map(get_continuous_events(sys)) do cev
828+
SymbolicContinuousCallback(
829+
map(substituter, cev.conditions), substituter(cev.affect),
830+
substituter(cev.affect_neg), substituter(cev.initialize),
831+
substituter(cev.finalize), cev.rootfind,
832+
cev.reinitializealg, cev.zero_crossing_id)
833+
end
834+
@set! sys.discrete_events = map(get_discrete_events(sys)) do dev
835+
SymbolicDiscreteCallback(map(substituter, dev.conditions), substituter(dev.affect),
836+
substituter(dev.initialize), substituter(dev.finalize), dev.reinitializealg)
837+
end
838+
if get_schedule(sys) !== nothing
839+
sched = get_schedule(sys)
840+
@set! sys.schedule = Schedule(
841+
sched.var_sccs, AnyDict(k => substituter(v) for (k, v) in sched.dummy_sub))
842+
end
843+
@set! sys.constraints = map(substituter, get_constraints(sys))
844+
@set! sys.tstops = map(substituter, get_tstops(sys))
845+
@set! sys.costs = Vector{Union{Real, BasicSymbolic}}(map(substituter, get_costs(sys)))
846+
sys = complete(sys; split = is_split(sys))
847+
return sys
848+
end

0 commit comments

Comments
 (0)