Skip to content

Commit 1ff060a

Browse files
refactor: modernize ODEInputFunction
1 parent 1897db8 commit 1ff060a

File tree

1 file changed

+13
-46
lines changed

1 file changed

+13
-46
lines changed

src/systems/optimal_control_interface.jl

Lines changed: 13 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,10 @@ end
4747

4848
is_explicit(tableau) = tableau isa DiffEqBase.ExplicitRKTableau
4949

50-
"""
51-
Generate the control function f(x, u, p, t) from the ODESystem.
52-
Input variables are automatically inferred but can be manually specified.
53-
"""
54-
function SciMLBase.ODEInputFunction{iip, specialize}(sys::System,
55-
dvs = unknowns(sys),
56-
ps = parameters(sys), u0 = nothing,
50+
@fallback_iip_specialize function SciMLBase.ODEInputFunction{iip, specialize}(sys::System;
5751
inputs = unbound_inputs(sys),
58-
disturbance_inputs = disturbances(sys);
59-
version = nothing, tgrad = false,
52+
disturbance_inputs = disturbances(sys),
53+
u0 = nothing, tgrad = false,
6054
jac = false, controljac = false,
6155
p = nothing, t = nothing,
6256
eval_expression = false,
@@ -66,7 +60,6 @@ function SciMLBase.ODEInputFunction{iip, specialize}(sys::System,
6660
checkbounds = false,
6761
sparsity = false,
6862
analytic = nothing,
69-
split_idxs = nothing,
7063
initialization_data = nothing,
7164
cse = true,
7265
kwargs...) where {iip, specialize}
@@ -75,61 +68,49 @@ function SciMLBase.ODEInputFunction{iip, specialize}(sys::System,
7568
f = f[1]
7669

7770
if tgrad
78-
tgrad_gen = generate_tgrad(sys, dvs, ps;
71+
_tgrad = generate_tgrad(sys;
7972
simplify = simplify,
8073
expression = Val{true},
74+
wrap_gfw = Val{true},
8175
expression_module = eval_module, cse,
8276
checkbounds = checkbounds, kwargs...)
83-
tgrad_oop, tgrad_iip = eval_or_rgf.(tgrad_gen; eval_expression, eval_module)
84-
_tgrad = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(tgrad_oop, tgrad_iip)
8577
else
8678
_tgrad = nothing
8779
end
8880

8981
if jac
90-
jac_gen = generate_jacobian(sys, dvs, ps;
82+
_jac = generate_jacobian(sys;
9183
simplify = simplify, sparse = sparse,
9284
expression = Val{true},
85+
wrap_gfw = Val{true},
9386
expression_module = eval_module, cse,
9487
checkbounds = checkbounds, kwargs...)
95-
jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module)
96-
97-
_jac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(jac_oop, jac_iip)
9888
else
9989
_jac = nothing
10090
end
10191

10292
if controljac
103-
cjac_gen = generate_control_jacobian(sys, dvs, ps;
93+
_cjac = generate_control_jacobian(sys;
10494
simplify = simplify, sparse = sparse,
105-
expression = Val{true},
95+
expression = Val{true}, wrap_gfw = Val{true},
10696
expression_module = eval_module, cse,
10797
checkbounds = checkbounds, kwargs...)
108-
cjac_oop, cjac_iip = eval_or_rgf.(cjac_gen; eval_expression, eval_module)
109-
110-
_cjac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(cjac_oop, cjac_iip)
11198
else
11299
_cjac = nothing
113100
end
114101

115102
M = calculate_massmatrix(sys)
116-
_M = if sparse && !(u0 === nothing || M === I)
117-
SparseArrays.sparse(M)
118-
elseif u0 === nothing || M === I
119-
M
120-
else
121-
ArrayInterface.restructure(u0 .* u0', M)
122-
end
103+
_M = concrete_massmatrix(M; sparse, u0)
123104

124105
observedfun = ObservedFunctionCache(
125106
sys; steady_state, eval_expression, eval_module, checkbounds, cse)
126107

108+
_W_sparsity = W_sparsity(sys)
109+
W_prototype = calculate_W_prototype(_W_sparsity; u0, sparse)
127110
if sparse
128111
uElType = u0 === nothing ? Float64 : eltype(u0)
129-
W_prototype = similar(W_sparsity(sys), uElType)
130112
controljac_prototype = similar(calculate_control_jacobian(sys), uElType)
131113
else
132-
W_prototype = nothing
133114
controljac_prototype = nothing
134115
end
135116

@@ -142,25 +123,11 @@ function SciMLBase.ODEInputFunction{iip, specialize}(sys::System,
142123
jac_prototype = W_prototype,
143124
controljac_prototype = controljac_prototype,
144125
observed = observedfun,
145-
sparsity = sparsity ? W_sparsity(sys) : nothing,
126+
sparsity = sparsity ? _W_sparsity : nothing,
146127
analytic = analytic,
147128
initialization_data)
148129
end
149130

150-
function SciMLBase.ODEInputFunction(sys::System, args...; kwargs...)
151-
ODEInputFunction{true}(sys, args...; kwargs...)
152-
end
153-
154-
function SciMLBase.ODEInputFunction{true}(sys::System, args...;
155-
kwargs...)
156-
ODEInputFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
157-
end
158-
159-
function SciMLBase.ODEInputFunction{false}(sys::System, args...;
160-
kwargs...)
161-
ODEInputFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
162-
end
163-
164131
# returns the JuMP timespan, the number of steps, and whether it is a free time problem.
165132
function process_tspan(tspan, dt, steps)
166133
is_free_time = false

0 commit comments

Comments
 (0)