Skip to content

Commit bb0fa0a

Browse files
authored
Merge pull request #1675 from SciML/myb/linearize
WIP: Add linearize function
2 parents 4048ae8 + ddd09dc commit bb0fa0a

File tree

7 files changed

+497
-15
lines changed

7 files changed

+497
-15
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1616
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1717
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1818
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
19+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1920
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
2021
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
2122
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -54,6 +55,7 @@ DiffRules = "0.1, 1.0"
5455
Distributions = "0.23, 0.24, 0.25"
5556
DocStringExtensions = "0.7, 0.8, 0.9"
5657
DomainSets = "0.5"
58+
ForwardDiff = "0.10.3"
5759
Graphs = "1.5.2"
5860
IfElse = "0.1"
5961
JuliaFormatter = "1"
@@ -82,6 +84,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
8284
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8385
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
8486
Ipopt_jll = "9cc047cb-c261-5740-88fc-0cf96f7bdcc7"
87+
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
8588
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
8689
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
8790
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
@@ -95,4 +98,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
9598
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9699

97100
[targets]
98-
test = ["AmplNLWriter", "BenchmarkTools", "ForwardDiff", "Ipopt", "Ipopt_jll", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]
101+
test = ["AmplNLWriter", "BenchmarkTools", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]

src/ModelingToolkit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ $(DocStringExtensions.README)
44
module ModelingToolkit
55
using DocStringExtensions
66
using AbstractTrees
7-
using DiffEqBase, SciMLBase, Reexport
7+
using DiffEqBase, SciMLBase, ForwardDiff, Reexport
88
using Distributed
99
using StaticArrays, LinearAlgebra, SparseArrays, LabelledArrays
1010
using InteractiveUtils
@@ -181,7 +181,7 @@ export Term, Sym
181181
export SymScope, LocalScope, ParentScope, GlobalScope
182182
export independent_variables, independent_variable, states, parameters, equations, controls,
183183
observed, structure, full_equations
184-
export structural_simplify, expand_connections
184+
export structural_simplify, expand_connections, linearize, linear_statespace
185185
export DiscreteSystem, DiscreteProblem
186186

187187
export calculate_jacobian, generate_jacobian, generate_function

src/inputoutput.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,20 @@ unbound_outputs(sys) = filter(x -> !is_bound(sys, x), outputs(sys))
6161
Determine whether or not input/output variable `u` is "bound" within the system, i.e., if it's to be considered internal to `sys`.
6262
A variable/signal is considered bound if it appears in an equation together with variables from other subsystems.
6363
The typical usecase for this function is to determine whether the input to an IO component is connected to another component,
64-
or if it remains an external input that the user has to supply before simulating the system.
64+
or if it remains an external input that the user has to supply before simulating the system.
6565
6666
See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref), [`bound_outputs`](@ref), [`unbound_outputs`](@ref)
6767
"""
6868
function is_bound(sys, u, stack = [])
6969
#=
70-
For observed quantities, we check if a variable is connected to something that is bound to something further out.
70+
For observed quantities, we check if a variable is connected to something that is bound to something further out.
7171
In the following scenario
7272
julia> observed(syss)
7373
2-element Vector{Equation}:
7474
sys₊y(tv) ~ sys₊x(tv)
7575
y(tv) ~ sys₊x(tv)
7676
sys₊y(t) is bound to the outer y(t) through the variable sys₊x(t) and should thus return is_bound(sys₊y(t)) = true.
77-
When asking is_bound(sys₊y(t)), we know that we are looking through observed equations and can thus ask
77+
When asking is_bound(sys₊y(t)), we know that we are looking through observed equations and can thus ask
7878
if var is bound, if it is, then sys₊y(t) is also bound. This can lead to an infinite recursion, so we maintain a stack of variables we have previously asked about to be able to break cycles
7979
=#
8080
u Set(stack) && return false # Cycle detected
@@ -241,7 +241,7 @@ function toparam(sys, ctrls::AbstractVector)
241241
ODESystem(eqs, name = nameof(sys))
242242
end
243243

244-
function inputs_to_parameters!(state::TransformationState)
244+
function inputs_to_parameters!(state::TransformationState, check_bound = true)
245245
@unpack structure, fullvars, sys = state
246246
@unpack var_to_diff, graph, solvable_graph = structure
247247
@assert solvable_graph === nothing
@@ -254,7 +254,7 @@ function inputs_to_parameters!(state::TransformationState)
254254
input_to_parameters = Dict()
255255
new_fullvars = []
256256
for (i, v) in enumerate(fullvars)
257-
if isinput(v) && !is_bound(sys, v)
257+
if isinput(v) && !(check_bound && is_bound(sys, v))
258258
if var_to_diff[i] !== nothing
259259
error("Input $(fullvars[i]) is differentiated!")
260260
end
@@ -270,7 +270,7 @@ function inputs_to_parameters!(state::TransformationState)
270270
push!(new_fullvars, v)
271271
end
272272
end
273-
ninputs == 0 && return state
273+
ninputs == 0 && return (state, 1:0)
274274

275275
nvars = ndsts(graph) - ninputs
276276
new_graph = BipartiteGraph(nsrcs(graph), nvars, Val(false))
@@ -296,9 +296,12 @@ function inputs_to_parameters!(state::TransformationState)
296296

297297
@set! sys.eqs = map(Base.Fix2(substitute, input_to_parameters), equations(sys))
298298
@set! sys.states = setdiff(states(sys), keys(input_to_parameters))
299-
@set! sys.ps = [parameters(sys); new_parameters]
299+
ps = parameters(sys)
300+
@set! sys.ps = [ps; new_parameters]
300301

301302
@set! state.sys = sys
302303
@set! state.fullvars = new_fullvars
303304
@set! state.structure = structure
305+
base_params = length(ps)
306+
return state, (base_params + 1):(base_params + length(new_parameters)) # (1:length(new_parameters)) .+ base_params
304307
end

0 commit comments

Comments
 (0)