Skip to content

Commit e2c7304

Browse files
committed
Adding generate_control_jacobian function
1 parent 0abd818 commit e2c7304

File tree

5 files changed

+18
-8
lines changed

5 files changed

+18
-8
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ export structural_simplify
167167
export DiscreteSystem, DiscreteProblem
168168

169169
export calculate_jacobian, generate_jacobian, generate_function
170-
export calculate_control_jacobian
170+
export calculate_control_jacobian, generate_control_jacobian
171171
export calculate_tgrad, generate_tgrad
172172
export calculate_gradient, generate_gradient
173173
export calculate_factorized_W, generate_factorized_W

src/systems/abstractsystem.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ for prop in [
152152
:iv
153153
:states
154154
:ps
155-
:ctrl
155+
:ctrls
156156
:defaults
157157
:observed
158158
:tgrad
@@ -315,6 +315,7 @@ end
315315

316316
namespace_variables(sys::AbstractSystem) = states(sys, states(sys))
317317
namespace_parameters(sys::AbstractSystem) = parameters(sys, parameters(sys))
318+
namespace_controls(sys::AbstractSystem) = controls(sys, controls(sys))
318319

319320
function namespace_defaults(sys)
320321
defs = defaults(sys)
@@ -358,17 +359,19 @@ function states(sys::AbstractSystem)
358359
systems = get_systems(sys)
359360
unique(isempty(systems) ?
360361
sts :
361-
[sts;reduce(vcat,namespace_variables.(systems))])
362+
[sts; reduce(vcat,namespace_variables.(systems))])
362363
end
363364

364365
function parameters(sys::AbstractSystem)
365366
ps = get_ps(sys)
366367
systems = get_systems(sys)
367-
isempty(systems) ? ps : [ps;reduce(vcat,namespace_parameters.(systems))]
368+
isempty(systems) ? ps : [ps; reduce(vcat,namespace_parameters.(systems))]
368369
end
369370

370371
function controls(sys::AbstractSystem)
371-
get_ctrl(flatten(sys))
372+
ctrls = get_ctrls(sys)
373+
systems = get_systems(sys)
374+
isempty(systems) ? ctrls : [ctrls; reduce(vcat,namespace_controls.(systems))]
372375
end
373376

374377
function observed(sys::AbstractSystem)

src/systems/control/controlsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
abstract type AbstractControlSystem <: AbstractSystem end
22

3-
function namespace_controls(sys::AbstractSystem)
3+
function namespace_controls(sys::AbstractControlSystem)
44
[rename(x,renamespace(nameof(sys),nameof(x))) for x in controls(sys)]
55
end
66

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = param
7272
return build_function(jac, dvs, ps, get_iv(sys); kwargs...)
7373
end
7474

75+
function generate_control_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
76+
simplify=false, sparse = false, kwargs...)
77+
jac = calculate_control_jacobian(sys;simplify=simplify,sparse=sparse)
78+
return build_function(jac, dvs, ps, get_iv(sys); kwargs...)
79+
end
80+
7581
"""
7682
```julia
7783
generate_linearization(sys::ODESystem, dvs = states(sys), ps = parameters(sys), ctrls = controls(sys), point = defaults(sys), expression = Val{true}; sparse = false, kwargs...)

src/systems/diffeqs/odesystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ struct ODESystem <: AbstractODESystem
3131
states::Vector
3232
"""Parameter variables. Must not contain the independent variable."""
3333
ps::Vector
34-
ctrl::Vector
34+
ctrls::Vector
3535
observed::Vector{Equation}
3636
"""
3737
Time-derivative matrix. Note: this field will not be defined until
@@ -115,13 +115,14 @@ function ODESystem(
115115

116116
tgrad = RefValue(Vector{Num}(undef, 0))
117117
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
118+
ctrl_jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
118119
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
119120
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
120121
sysnames = nameof.(systems)
121122
if length(unique(sysnames)) != length(sysnames)
122123
throw(ArgumentError("System names must be unique."))
123124
end
124-
ODESystem(deqs, iv′, dvs′, ps′, ctrl′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
125+
ODESystem(deqs, iv′, dvs′, ps′, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
125126
end
126127

127128
vars(x::Sym) = Set([x])

0 commit comments

Comments
 (0)