Skip to content

Commit bc763dc

Browse files
modelingtoolkitization and function generation
1 parent 702442b commit bc763dc

File tree

2 files changed

+59
-4
lines changed

2 files changed

+59
-4
lines changed

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,27 @@ function modelingtoolkitize(prob::DiffEqBase.SDEProblem)
8484

8585
de
8686
end
87+
88+
89+
"""
90+
$(TYPEDSIGNATURES)
91+
92+
Generate `OptimizationSystem`, dependent variables, and parameters from an `OptimizationProblem`.
93+
"""
94+
function modelingtoolkitize(prob::DiffEqBase.OptimizationProblem)
95+
96+
if prob.p isa Tuple || prob.p isa NamedTuple
97+
p = [x for x in prob.p]
98+
else
99+
p = prob.p
100+
end
101+
102+
vars = reshape([Variable(:x, i)(t) for i in eachindex(prob.u0)],size(prob.u0))
103+
params = p isa DiffEqBase.NullParameters ? [] :
104+
reshape([Variable(,i)() for i in eachindex(p)],size(Array(p)))
105+
106+
107+
eqs = prob.f(vars, params)
108+
de = OptimizationSystem(eqs,vec(vars),vec(params))
109+
de
110+
end

src/systems/optimization/optimizationsystem.jl

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ namespace_operation(sys::OptimizationSystem) = namespace_operation(sys.op,sys.na
8484
hessian_sparsity(sys::OptimizationSystem) =
8585
hessian_sparsity(sys.op,[dv() for dv in states(sys)])
8686

87-
struct ManualModelingToolkit <: DiffEqBase.AbstractADType end
87+
struct AutoModelingToolkit <: DiffEqBase.AbstractADType end
8888

8989
"""
9090
```julia
@@ -134,7 +134,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0,
134134
_hess = nothing
135135
end
136136

137-
_f = OptimizationFunction{iip,typeof(f),typeof(_grad),typeof(_hess),Nothing,Nothing,Nothing,Nothing}(f,_grad,_hess,nothing,ManualModelingToolkit(),nothing,nothing,nothing,0)
137+
_f = OptimizationFunction{iip,typeof(f),typeof(_grad),typeof(_hess),Nothing,Nothing,Nothing,Nothing}(f,_grad,_hess,nothing,AutoModelingToolkit(),nothing,nothing,nothing,0)
138138

139139
p = varmap_to_vars(parammap,ps)
140140
lb = varmap_to_vars(lb,dvs)
@@ -170,9 +170,23 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
170170
kwargs...) where iip
171171
dvs = states(sys)
172172
ps = parameters(sys)
173-
173+
idx = iip ? 2 : 1
174174
f = generate_function(sys,checkbounds=checkbounds,linenumbers=linenumbers,
175175
expression=Val{true})
176+
if grad
177+
_grad = generate_gradient(sys,checkbounds=checkbounds,linenumbers=linenumbers,
178+
parallel=parallel,expression=Val{false})[idx]
179+
else
180+
_grad = :nothing
181+
end
182+
183+
if hess
184+
_hess = generate_hessian(sys,checkbounds=checkbounds,linenumbers=linenumbers,
185+
sparse=sparse,parallel=parallel,expression=Val{false})[idx]
186+
else
187+
_hess = :nothing
188+
end
189+
176190
u0 = varmap_to_vars(u0,dvs)
177191
p = varmap_to_vars(parammap,ps)
178192
lb = varmap_to_vars(lb,dvs)
@@ -181,8 +195,25 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
181195
f = $f
182196
p = $p
183197
u0 = $u0
198+
grad = $_grad
199+
hess = $_hess
184200
lb = $lb
185201
ub = $ub
186-
OptimizationProblem(f,u0,p;lb=lb,ub=ub,kwargs...)
202+
_f = OptimizationFunction{$iip,typeof(f),typeof(grad),typeof(hess),Nothing,Nothing,Nothing,Nothing}(f,grad,hess,nothing,AutoModelingToolkit(),nothing,nothing,nothing,0)
203+
OptimizationProblem{$iip}(_f,u0,p;lb=lb,ub=ub,kwargs...)
204+
end
205+
end
206+
207+
function OptimizationFunction(f, x, ::AutoModelingToolkit,p = DiffEqBase.NullParameters();
208+
grad=nothing, hess=nothing, cons = nothing, cons_j = nothing, cons_h = nothing,
209+
num_cons = 0, chunksize = 1, hv = nothing)
210+
211+
sys = modelingtoolkitize(OptimizationProblem(f,x,p))
212+
u0map = states(sys) .=> x
213+
if p == DiffEqBase.NullParameters()
214+
parammap = DiffEqBase.NullParameters()
215+
else
216+
parammap = parameters(sys) .=> p
187217
end
218+
OptimiationProblem(sys,u0map,parammap,grad=grad,hess=hess)
188219
end

0 commit comments

Comments
 (0)