Skip to content

Commit d8cd188

Browse files
YingboMabaggepinnen
andcommitted
Add linearize function
Co-authored-by: Fredrik Bagge Carlson <[email protected]>
1 parent 942a24d commit d8cd188

File tree

6 files changed

+182
-12
lines changed

6 files changed

+182
-12
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1717
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1818
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1919
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
20+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2021
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
2122
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
2223
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -55,6 +56,7 @@ DiffRules = "0.1, 1.0"
5556
Distributions = "0.23, 0.24, 0.25"
5657
DocStringExtensions = "0.7, 0.8, 0.9"
5758
DomainSets = "0.5"
59+
ForwardDiff = "0.10.3"
5860
Graphs = "1.5.2"
5961
IfElse = "0.1"
6062
JuliaFormatter = "1"

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ $(DocStringExtensions.README)
44
module ModelingToolkit
55
using DocStringExtensions
66
using AbstractTrees
7-
using DiffEqBase, SciMLBase, Reexport
7+
using DiffEqBase, SciMLBase, ForwardDiff, Reexport
88
using Distributed
99
using StaticArrays, LinearAlgebra, SparseArrays, LabelledArrays
1010
using InteractiveUtils

src/inputoutput.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,20 @@ unbound_outputs(sys) = filter(x -> !is_bound(sys, x), outputs(sys))
6161
Determine whether or not input/output variable `u` is "bound" within the system, i.e., if it's to be considered internal to `sys`.
6262
A variable/signal is considered bound if it appears in an equation together with variables from other subsystems.
6363
The typical usecase for this function is to determine whether the input to an IO component is connected to another component,
64-
or if it remains an external input that the user has to supply before simulating the system.
64+
or if it remains an external input that the user has to supply before simulating the system.
6565
6666
See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref), [`bound_outputs`](@ref), [`unbound_outputs`](@ref)
6767
"""
6868
function is_bound(sys, u, stack = [])
6969
#=
70-
For observed quantities, we check if a variable is connected to something that is bound to something further out.
70+
For observed quantities, we check if a variable is connected to something that is bound to something further out.
7171
In the following scenario
7272
julia> observed(syss)
7373
2-element Vector{Equation}:
7474
sys₊y(tv) ~ sys₊x(tv)
7575
y(tv) ~ sys₊x(tv)
7676
sys₊y(t) is bound to the outer y(t) through the variable sys₊x(t) and should thus return is_bound(sys₊y(t)) = true.
77-
When asking is_bound(sys₊y(t)), we know that we are looking through observed equations and can thus ask
77+
When asking is_bound(sys₊y(t)), we know that we are looking through observed equations and can thus ask
7878
if var is bound, if it is, then sys₊y(t) is also bound. This can lead to an infinite recursion, so we maintain a stack of variables we have previously asked about to be able to break cycles
7979
=#
8080
u Set(stack) && return false # Cycle detected
@@ -241,7 +241,7 @@ function toparam(sys, ctrls::AbstractVector)
241241
ODESystem(eqs, name = nameof(sys))
242242
end
243243

244-
function inputs_to_parameters!(state::TransformationState)
244+
function inputs_to_parameters!(state::TransformationState, check_bound = true)
245245
@unpack structure, fullvars, sys = state
246246
@unpack var_to_diff, graph, solvable_graph = structure
247247
@assert solvable_graph === nothing
@@ -254,7 +254,7 @@ function inputs_to_parameters!(state::TransformationState)
254254
input_to_parameters = Dict()
255255
new_fullvars = []
256256
for (i, v) in enumerate(fullvars)
257-
if isinput(v) && !is_bound(sys, v)
257+
if isinput(v) && !(check_bound && is_bound(sys, v))
258258
if var_to_diff[i] !== nothing
259259
error("Input $(fullvars[i]) is differentiated!")
260260
end
@@ -296,9 +296,12 @@ function inputs_to_parameters!(state::TransformationState)
296296

297297
@set! sys.eqs = map(Base.Fix2(substitute, input_to_parameters), equations(sys))
298298
@set! sys.states = setdiff(states(sys), keys(input_to_parameters))
299-
@set! sys.ps = [parameters(sys); new_parameters]
299+
ps = parameters(sys)
300+
@set! sys.ps = [ps; new_parameters]
300301

301302
@set! state.sys = sys
302303
@set! state.fullvars = new_fullvars
303304
@set! state.structure = structure
305+
base_params = length(ps)
306+
return state, (base_params + 1):(base_params + length(new_parameters)) # (1:length(new_parameters)) .+ base_params
304307
end

src/systems/abstractsystem.jl

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -950,10 +950,10 @@ types during tearing.
950950
"""
951951
function structural_simplify(sys::AbstractSystem; simplify = false, kwargs...)
952952
sys = expand_connections(sys)
953-
sys = alias_elimination(sys)
954953
state = TearingState(sys)
955-
state = inputs_to_parameters!(state)
956-
sys = state.sys
954+
state, = inputs_to_parameters!(state)
955+
sys = alias_elimination!(state)
956+
state = TearingState(sys)
957957
check_consistency(state)
958958
if sys isa ODESystem
959959
sys = dae_order_lowering(dummy_derivative(sys, state))
@@ -967,6 +967,87 @@ function structural_simplify(sys::AbstractSystem; simplify = false, kwargs...)
967967
return sys
968968
end
969969

970+
export linearize
971+
# TODO: The order of the states and equations should match so that the Jacobians are ẋ = Ax
972+
function linearize(sys::AbstractSystem, inputs, outputs; simplify = false, kwargs...)
973+
sys = expand_connections(sys)
974+
state = TearingState(sys)
975+
markio!(state, inputs, outputs)
976+
state, input_idxs = inputs_to_parameters!(state, false)
977+
sys = alias_elimination!(state)
978+
state = TearingState(sys)
979+
check_consistency(state)
980+
if sys isa ODESystem
981+
sys = dae_order_lowering(dummy_derivative(sys, state))
982+
end
983+
state = TearingState(sys)
984+
find_solvables!(state; kwargs...)
985+
sys = tearing_reassemble(state, tearing(state), simplify = simplify)
986+
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
987+
@set! sys.observed = topsort_equations(observed(sys), fullstates)
988+
invalidate_cache!(sys)
989+
990+
eqs = equations(sys)
991+
check_operator_variables(eqs, Differential)
992+
# Sort equations and states such that diff.eqs. match differential states and the rest are algebraic
993+
diffstates = collect_operator_variables(sys, Differential)
994+
eqs = sort(eqs, by = e -> !isoperator(e.lhs, Differential),
995+
alg = Base.Sort.DEFAULT_STABLE)
996+
@set! sys.eqs = eqs
997+
diffstates = [arguments(e.lhs)[1] for e in eqs[1:length(diffstates)]]
998+
sts = [diffstates; setdiff(states(sys), diffstates)]
999+
@set! sys.states = sts
1000+
1001+
diff_idxs = 1:length(diffstates)
1002+
alge_idxs = (length(diffstates) + 1):length(sts)
1003+
fun = ODEFunction(sys)
1004+
lin_fun = let fun = fun,
1005+
h = ModelingToolkit.build_explicit_observed_function(sys, outputs)
1006+
1007+
(u, p, t) -> begin
1008+
uf = SciMLBase.UJacobianWrapper(fun, t, p)
1009+
fg_xz = ForwardDiff.jacobian(uf, u)
1010+
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
1011+
# TODO: this is very inefficient, p contains all parameters of the system
1012+
fg_u = ForwardDiff.jacobian(pf, p)[:, input_idxs]
1013+
h_xz = ForwardDiff.jacobian(xz -> h(xz, p, t), u)
1014+
h_u = ForwardDiff.jacobian(p -> h(u, p, t), p)[:, input_idxs]
1015+
(f_x = fg_xz[diff_idxs, diff_idxs],
1016+
f_z = fg_xz[diff_idxs, alge_idxs],
1017+
g_x = fg_xz[alge_idxs, diff_idxs],
1018+
g_z = fg_xz[alge_idxs, alge_idxs],
1019+
f_u = fg_u[diff_idxs, :],
1020+
g_u = fg_u[alge_idxs, :],
1021+
h_x = h_xz[:, diff_idxs],
1022+
h_z = h_xz[:, alge_idxs],
1023+
h_u = h_u)
1024+
end
1025+
end
1026+
return sys, lin_fun
1027+
end
1028+
1029+
function markio!(state::TearingState, inputs, outputs)
1030+
fullvars = state.fullvars
1031+
inputset = Set(inputs)
1032+
outputset = Set(outputs)
1033+
for (i, v) in enumerate(fullvars)
1034+
if v in inputset
1035+
v = setmetadata(v, ModelingToolkit.VariableInput, true)
1036+
v = setmetadata(v, ModelingToolkit.VariableOutput, false)
1037+
fullvars[i] = v
1038+
elseif v in outputset
1039+
v = setmetadata(v, ModelingToolkit.VariableInput, false)
1040+
v = setmetadata(v, ModelingToolkit.VariableOutput, true)
1041+
fullvars[i] = v
1042+
else
1043+
v = setmetadata(v, ModelingToolkit.VariableInput, false)
1044+
v = setmetadata(v, ModelingToolkit.VariableOutput, false)
1045+
fullvars[i] = v
1046+
end
1047+
end
1048+
state
1049+
end
1050+
9701051
@latexrecipe function f(sys::AbstractSystem)
9711052
return latexify(equations(sys))
9721053
end

src/systems/alias_elimination.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ function aag_bareiss(sys::AbstractSystem)
1919
return aag_bareiss!(state.structure.graph, complete(state.structure.var_to_diff), mm)
2020
end
2121

22-
function alias_elimination(sys)
23-
state = TearingState(sys; quick_cancel = true)
22+
alias_elimination(sys) = alias_elimination!(TearingState(sys; quick_cancel = true))
23+
function alias_elimination!(state::TearingState)
24+
sys = state.sys
2425
ag, mm = alias_eliminate_graph!(state)
2526
ag === nothing && return sys
2627

test/linearize.jl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
using ModelingToolkit
2+
3+
# r is an input, and y is an output.
4+
@variables t x(t)=0 y(t)=0 u(t)=0 r(t)=0
5+
@variables t x(t)=0 y(t)=0 u(t)=0 r(t)=0 [input = true]
6+
@parameters kp = 1
7+
D = Differential(t)
8+
9+
eqs = [u ~ kp * (r - y)
10+
D(x) ~ -x + u
11+
y ~ x]
12+
13+
@named sys = ODESystem(eqs, t)
14+
linearize(sys, [r], [y])
15+
16+
##
17+
```
18+
19+
r ┌─────┐ ┌─────┐ ┌─────┐
20+
───►│ ├──────►│ │ u │ │
21+
│ F │ │ C ├────►│ P │ y
22+
└─────┘ ┌►│ │ │ ├─┬─►
23+
│ └─────┘ └─────┘ │
24+
│ │
25+
└─────────────────────┘
26+
```
27+
28+
function plant(; name)
29+
@variables x(t) = 1
30+
@variables u(t)=0 [input = true] y(t)=0 [output = true]
31+
D = Differential(t)
32+
eqs = [D(x) ~ -x + u
33+
y ~ x]
34+
ODESystem(eqs, t; name = name)
35+
end
36+
37+
function filt_(; name)
38+
@variables x(t)=0 y(t)=0 [output = true]
39+
@variables u(t)=0 [input = true]
40+
D = Differential(t)
41+
eqs = [D(x) ~ -2 * x + u
42+
y ~ x]
43+
ODESystem(eqs, t, name = name)
44+
end
45+
46+
function controller(kp; name)
47+
@variables y(t)=0 r(t)=0 [input = true] u(t)=0
48+
@parameters kp = kp
49+
eqs = [
50+
u ~ kp * (r - y),
51+
]
52+
ODESystem(eqs, t; name = name)
53+
end
54+
55+
@named f = filt_()
56+
@named c = controller(1)
57+
@named p = plant()
58+
59+
connections = [f.y ~ c.r # filtered reference to controller reference
60+
c.u ~ p.u # controller output to plant input
61+
p.y ~ c.y]
62+
63+
@named cl = ODESystem(connections, t, systems = [f, c, p])
64+
65+
lin, xs = linearize(cl, cl.f.u, cl.p.x)
66+
67+
##
68+
using ModelingToolkitStandardLibrary.Blocks: LimPID
69+
#using ControlSystems
70+
k = 400;
71+
Ti = 0.5;
72+
Td = 1;
73+
Nd = 10;
74+
#s = tf("s")
75+
#expected_result_r = k*(1 + 1/(s*Ti)) |> ss
76+
#expected_result_y = k*(1 + 1/(s*Ti) - s*Td / (1 + s*Td/N)) |> ss
77+
@named pid = LimPID(; k, Ti, Td, Nd)
78+
ModelingToolkit.unbound_inputs(pid)
79+
80+
@unpack reference, measurement, ctr_output = pid
81+
lin = linearize(pid, [reference.u, measurement.u], [ctr_output.u])
82+
lin, lin_fun = linearize(pid, [reference.u, measurement.u], [ctr_output.u]);
83+
lin_fun(prob.u0, prob.p, 0.0)

0 commit comments

Comments
 (0)