Skip to content

Commit 997f973

Browse files
fix generation functions, tests now pass
1 parent 7951dbb commit 997f973

File tree

3 files changed

+18
-9
lines changed

3 files changed

+18
-9
lines changed

src/systems/optimization/optimizationsystem.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ hessian_sparsity(sys::OptimizationSystem) =
8686

8787
struct AutoModelingToolkit <: DiffEqBase.AbstractADType end
8888

89+
DiffEqBase.OptimizationProblem(sys::OptimizationSystem,args...;kwargs...) =
90+
DiffEqBase.OptimizationProblem{true}(sys::OptimizationSystem,args...;kwargs...)
91+
8992
"""
9093
```julia
9194
function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
@@ -120,7 +123,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0,
120123
grad_oop,grad_iip = generate_gradient(sys,checkbounds=checkbounds,linenumbers=linenumbers,
121124
parallel=parallel,expression=Val{false})
122125
_grad(u,p) = grad_oop(u,p)
123-
_grad(J,u,p) = grad_iip(J,u,p)
126+
_grad(J,u,p) = (grad_iip(J,u,p); J)
124127
else
125128
_grad = nothing
126129
end
@@ -129,7 +132,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0,
129132
hess_oop,hess_iip = generate_hessian(sys,checkbounds=checkbounds,linenumbers=linenumbers,
130133
sparse=sparse,parallel=parallel,expression=Val{false})
131134
_hess(u,p) = hess_oop(u,p)
132-
_hess(J,u,p) = hess_iip(J,u,p)
135+
_hess(J,u,p) = (hess_iip(J,u,p); J)
133136
else
134137
_hess = nothing
135138
end

test/modelingtoolkitize.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,12 @@ p = [1.0,100.0]
4949

5050
prob = OptimizationProblem(rosenbrock,x0,p)
5151
sys = modelingtoolkitize(prob)
52+
x0map = states(sys) .=> x0
53+
parammap = parameters(sys) .=> p
5254

53-
prob = OptimizationProblem(ModelingToolkit.OptimizationFunction(
54-
rosenbrock,x0,ModelingToolkit.AutoModelingToolkit(),p,
55-
grad = true), x0,p)
55+
prob = OptimizationProblem(sys,x0map,parammap,grad=true)
5656
sol = solve(prob,NelderMead())
5757
@test sol.minimum < 1e-8
5858

5959
sol = solve(prob,BFGS())
60-
@test_broken sol.minimum < 1e-8
60+
@test sol.minimum < 1e-8

test/optimizationsystem.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, SparseArrays
1+
using ModelingToolkit, SparseArrays, Test, GalacticOptim, Optim
22

33
@variables x y
44
@parameters a b
@@ -36,7 +36,13 @@ p = [
3636
sys2.b => 9.0
3737
β => 10.0
3838
]
39+
3940
prob = OptimizationProblem(combinedsys,u0,p,grad=true)
41+
sol = solve(prob,NelderMead())
42+
@test sol.minimum < -1e5
4043

41-
using GalacticOptim, Optim
42-
solve(prob,BFGS())
44+
prob2 = remake(prob,u0=sol.minimizer)
45+
sol = solve(prob,BFGS(initial_stepnorm=0.0001),allow_f_increases=true)
46+
@test sol.minimum < -1e8
47+
sol = solve(prob2,BFGS(initial_stepnorm=0.0001),allow_f_increases=true)
48+
@test sol.minimum < -1e9

0 commit comments

Comments
 (0)