Skip to content

Commit 4414ca8

Browse files
committed
Adding controls kwarg and controls jacobian to AbstractODESystem
1 parent d7b8f96 commit 4414ca8

File tree

4 files changed

+77
-3
lines changed

4 files changed

+77
-3
lines changed

src/ModelingToolkit.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,13 @@ export Differential, expand_derivatives, @derivatives
161161
export Equation, ConstrainedEquation
162162
export Term, Sym
163163
export SymScope, LocalScope, ParentScope, GlobalScope
164-
export independent_variable, states, parameters, equations, controls, observed, structure
164+
export independent_variable, states, parameters, equations, controls, observed, structure, defaults
165+
export ssmodel, linearize
165166
export structural_simplify
166167
export DiscreteSystem, DiscreteProblem
167168

168169
export calculate_jacobian, generate_jacobian, generate_function
170+
export calculate_control_jacobian
169171
export calculate_tgrad, generate_tgrad
170172
export calculate_gradient, generate_gradient
171173
export calculate_factorized_W, generate_factorized_W

src/systems/abstractsystem.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,18 @@ call will be cached in the system object.
3434
"""
3535
function calculate_jacobian end
3636

37+
"""
38+
```julia
39+
calculate_control_jacobian(sys::AbstractSystem)
40+
```
41+
42+
Calculate the jacobian matrix of a system with respect to the system's controls.
43+
44+
Returns a matrix of [`Num`](@ref) instances. The result from the first
45+
call will be cached in the system object.
46+
"""
47+
function calculate_control_jacobian end
48+
3749
"""
3850
```julia
3951
calculate_factorized_W(sys::AbstractSystem)
@@ -140,10 +152,12 @@ for prop in [
140152
:iv
141153
:states
142154
:ps
155+
:ctrl
143156
:defaults
144157
:observed
145158
:tgrad
146159
:jac
160+
:ctrl_jac
147161
:Wfact
148162
:Wfact_t
149163
:systems
@@ -346,11 +360,17 @@ function states(sys::AbstractSystem)
346360
sts :
347361
[sts;reduce(vcat,namespace_variables.(systems))])
348362
end
363+
349364
function parameters(sys::AbstractSystem)
350365
ps = get_ps(sys)
351366
systems = get_systems(sys)
352367
isempty(systems) ? ps : [ps;reduce(vcat,namespace_parameters.(systems))]
353368
end
369+
370+
function controls(sys::AbstractSystem)
371+
get_ctrl(sys)
372+
end
373+
354374
function observed(sys::AbstractSystem)
355375
iv = independent_variable(sys)
356376
obs = get_observed(sys)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,28 @@ function calculate_jacobian(sys::AbstractODESystem;
3838
return jac
3939
end
4040

41+
function calculate_control_jacobian(sys::AbstractODESystem;
42+
sparse=false, simplify=false)
43+
cache = get_ctrl_jac(sys)[]
44+
if cache isa Tuple && cache[2] == (sparse, simplify)
45+
return cache[1]
46+
end
47+
48+
rhs = [eq.rhs for eq equations(sys)]
49+
50+
iv = get_iv(sys)
51+
ctrls = controls(sys)
52+
53+
if sparse
54+
jac = sparsejacobian(rhs, ctrls, simplify=simplify)
55+
else
56+
jac = jacobian(rhs, ctrls, simplify=simplify)
57+
end
58+
59+
get_ctrl_jac(sys)[] = jac, (sparse, simplify) # cache Jacobian
60+
return jac
61+
end
62+
4163
function generate_tgrad(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
4264
simplify=false, kwargs...)
4365
tgrad = calculate_tgrad(sys,simplify=simplify)
@@ -50,6 +72,25 @@ function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = param
5072
return build_function(jac, dvs, ps, get_iv(sys); kwargs...)
5173
end
5274

75+
"""
76+
```julia
77+
generate_linearization(sys::ODESystem, dvs = states(sys), ps = parameters(sys), ctrls = controls(sys), point = defaults(sys), expression = Val{true}; sparse = false, kwargs...)
78+
```
79+
80+
Generates a function for the linearized state space model of the system. Extra arguments
81+
control the arguments to the internal [`build_function`](@ref) call.
82+
"""
83+
function generate_linearization(sys::AbstractSystem, dvs = states(sys), ps = parameters(sys), ctrls = controls(sys);
84+
simplify=false, sparse=false, kwargs...)
85+
ops = map(eq -> eq.rhs, equations(sys))
86+
87+
jac = calculate_jacobian(sys;simplify=simplify,sparse=sparse)
88+
89+
A = @views J[1:length(states(sys)), 1:length(states(sys))]
90+
B = @views J[1:length(states(sys)), length(states(sys))+1:end]
91+
92+
return A, B
93+
end
5394
@noinline function throw_invalid_derivative(dervar, eq)
5495
msg = "The derivative variable must be isolated to the left-hand " *
5596
"side of the equation like `$dervar ~ ...`.\n Got $eq."

src/systems/diffeqs/odesystem.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ struct ODESystem <: AbstractODESystem
3131
states::Vector
3232
"""Parameter variables. Must not contain the independent variable."""
3333
ps::Vector
34+
ctrl::Vector
3435
observed::Vector{Equation}
3536
"""
3637
Time-derivative matrix. Note: this field will not be defined until
@@ -43,6 +44,11 @@ struct ODESystem <: AbstractODESystem
4344
"""
4445
jac::RefValue{Any}
4546
"""
47+
Control Jacobian matrix. Note: this field will not be defined until
48+
[`calculate_control_jacobian`](@ref) is called on the system.
49+
"""
50+
ctrl_jac::RefValue{Any}
51+
"""
4652
`Wfact` matrix. Note: this field will not be defined until
4753
[`generate_factorized_W`](@ref) is called on the system.
4854
"""
@@ -84,6 +90,7 @@ end
8490

8591
function ODESystem(
8692
deqs::AbstractVector{<:Equation}, iv, dvs, ps;
93+
controls = Num[],
8794
observed = Num[],
8895
systems = ODESystem[],
8996
name=gensym(:ODESystem),
@@ -92,9 +99,13 @@ function ODESystem(
9299
defaults=_merge(Dict(default_u0), Dict(default_p)),
93100
connection_type=nothing,
94101
)
102+
103+
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
104+
95105
iv′ = value(scalarize(iv))
96106
dvs′ = value.(scalarize(dvs))
97107
ps′ = value.(scalarize(ps))
108+
ctrl′ = value.(scalarize(controls))
98109

99110
if !(isempty(default_u0) && isempty(default_p))
100111
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.", :ODESystem, force=true)
@@ -110,7 +121,7 @@ function ODESystem(
110121
if length(unique(sysnames)) != length(sysnames)
111122
throw(ArgumentError("System names must be unique."))
112123
end
113-
ODESystem(deqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
124+
ODESystem(deqs, iv′, dvs′, ps′, ctrl′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
114125
end
115126

116127
vars(x::Sym) = Set([x])
@@ -349,4 +360,4 @@ function convert_system(::Type{<:ODESystem}, sys, t; name=nameof(sys))
349360
neweqs = map(sub, equations(sys))
350361
defs = Dict(sub(k) => sub(v) for (k, v) in defaults(sys))
351362
return ODESystem(neweqs, t, newsts, parameters(sys); defaults=defs, name=name)
352-
end
363+
end

0 commit comments

Comments
 (0)