Skip to content

Commit 702442b

Browse files
Manual MTK works
1 parent 8ee5d30 commit 702442b

File tree

2 files changed

+56
-10
lines changed

2 files changed

+56
-10
lines changed

src/systems/optimization/optimizationsystem.jl

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,15 @@ namespace_operation(sys::OptimizationSystem) = namespace_operation(sys.op,sys.na
8484
hessian_sparsity(sys::OptimizationSystem) =
8585
hessian_sparsity(sys.op,[dv() for dv in states(sys)])
8686

87+
struct ManualModelingToolkit <: DiffEqBase.AbstractADType end
88+
8789
"""
8890
```julia
8991
function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
9092
parammap=DiffEqBase.NullParameters();
9193
u0=nothing, lb=nothing, ub=nothing,
92-
hes = false, sparse = false,
94+
grad = false,
95+
hess = false, sparse = false,
9396
checkbounds = false,
9497
linenumbers = true, parallel=SerialForm(),
9598
kwargs...) where iip
@@ -98,30 +101,53 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
98101
Generates an OptimizationProblem from an OptimizationSystem and allows for automatically
99102
symbolically calculating numerical enhancements.
100103
"""
101-
function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
104+
function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0,
102105
parammap=DiffEqBase.NullParameters();
103-
u0=nothing, lb=nothing, ub=nothing,
104-
hes = false, sparse = false,
106+
lb=nothing, ub=nothing,
107+
grad = false,
108+
hess = false, sparse = false,
105109
checkbounds = false,
106110
linenumbers = true, parallel=SerialForm(),
107111
kwargs...) where iip
108112
dvs = states(sys)
109113
ps = parameters(sys)
110114

111115
f = generate_function(sys,checkbounds=checkbounds,linenumbers=linenumbers,
112-
parallel=parallel,expression=Val{false})
116+
expression=Val{false})
113117
u0 = varmap_to_vars(u0,dvs)
118+
119+
if grad
120+
grad_oop,grad_iip = generate_gradient(sys,checkbounds=checkbounds,linenumbers=linenumbers,
121+
parallel=parallel,expression=Val{false})
122+
_grad(u,p) = grad_oop(u,p)
123+
_grad(J,u,p) = grad_iip(J,u,p)
124+
else
125+
_grad = nothing
126+
end
127+
128+
if hess
129+
hess_oop,hess_iip = generate_hessian(sys,checkbounds=checkbounds,linenumbers=linenumbers,
130+
sparse=sparse,parallel=parallel,expression=Val{false})
131+
_hess(u,p) = hess_oop(u,p)
132+
_hess(J,u,p) = hess_iip(J,u,p)
133+
else
134+
_hess = nothing
135+
end
136+
137+
_f = OptimizationFunction{iip,typeof(f),typeof(_grad),typeof(_hess),Nothing,Nothing,Nothing,Nothing}(f,_grad,_hess,nothing,ManualModelingToolkit(),nothing,nothing,nothing,0)
138+
114139
p = varmap_to_vars(parammap,ps)
115140
lb = varmap_to_vars(lb,dvs)
116141
ub = varmap_to_vars(ub,dvs)
117-
OptimizationProblem(f,p;u0=u0,lb=lb,ub=ub,kwargs...)
142+
OptimizationProblem{iip}(_f,u0,p;lb=lb,ub=ub,kwargs...)
118143
end
119144

120145
"""
121146
```julia
122147
function DiffEqBase.OptimizationProblemExpr{iip}(sys::OptimizationSystem,
123148
parammap=DiffEqBase.NullParameters();
124149
u0=nothing, lb=nothing, ub=nothing,
150+
grad = false,
125151
hes = false, sparse = false,
126152
checkbounds = false,
127153
linenumbers = true, parallel=SerialForm(),
@@ -134,9 +160,10 @@ calculating numerical enhancements.
134160
"""
135161
struct OptimizationProblemExpr{iip} end
136162

137-
function OptimizationProblemExpr{iip}(sys::OptimizationSystem,
163+
function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
138164
parammap=DiffEqBase.NullParameters();
139-
u0=nothing, lb=nothing, ub=nothing,
165+
lb=nothing, ub=nothing,
166+
grad = true,
140167
hes = false, sparse = false,
141168
checkbounds = false,
142169
linenumbers = false, parallel=SerialForm(),
@@ -145,7 +172,7 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem,
145172
ps = parameters(sys)
146173

147174
f = generate_function(sys,checkbounds=checkbounds,linenumbers=linenumbers,
148-
parallel=parallel,expression=Val{true})
175+
expression=Val{true})
149176
u0 = varmap_to_vars(u0,dvs)
150177
p = varmap_to_vars(parammap,ps)
151178
lb = varmap_to_vars(lb,dvs)
@@ -156,6 +183,6 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem,
156183
u0 = $u0
157184
lb = $lb
158185
ub = $ub
159-
OptimizationProblem(f,p;u0=u0,lb=lb,ub=ub,kwargs...)
186+
OptimizationProblem(f,u0,p;lb=lb,ub=ub,kwargs...)
160187
end
161188
end

test/optimizationsystem.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,22 @@ generate_function(combinedsys)
2121
generate_gradient(combinedsys)
2222
generate_hessian(combinedsys)
2323
ModelingToolkit.hessian_sparsity(combinedsys)
24+
25+
u0 = [
26+
sys1.x=>1.0
27+
sys1.y=>2.0
28+
sys2.x=>3.0
29+
sys2.y=>4.0
30+
z=>5.0
31+
]
32+
p = [
33+
sys1.a => 6.0
34+
sys1.b => 7.0
35+
sys2.a => 8.0
36+
sys2.b => 9.0
37+
β => 10.0
38+
]
39+
prob = OptimizationProblem(combinedsys,u0,p,grad=true)
40+
41+
using GalacticOptim, Optim
42+
solve(prob,BFGS())

0 commit comments

Comments
 (0)