Skip to content

Commit 7dbf18c

Browse files
Removes OptimizationFunction and add GalacticOptim as a dependency (#1563)
1 parent ee8890b commit 7dbf18c

File tree

8 files changed

+74
-28
lines changed

8 files changed

+74
-28
lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
8383
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8484
GalacticOptim = "a75be94c-b780-496d-a8a9-0878b188d577"
8585
GalacticOptimJL = "9d3c5eb1-403b-401b-8c0f-c11105342e6b"
86-
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
8786
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
8887
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
8988
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
@@ -94,4 +93,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
9493
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9594

9695
[targets]
97-
test = ["BenchmarkTools", "ForwardDiff", "GalacticOptim", "GalacticOptimJL", "OrdinaryDiffEq", "Optim", "Random", "ReferenceTests", "SafeTestsets", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]
96+
test = ["BenchmarkTools", "ForwardDiff", "GalacticOptim", "GalacticOptimJL", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ using DocStringExtensions
2525
using Base: RefValue
2626
using Combinatorics
2727
import IfElse
28-
2928
import Distributions
3029

3130
RuntimeGeneratedFunctions.init(@__MODULE__)
@@ -133,6 +132,7 @@ include("systems/diffeqs/basic_transformations.jl")
133132
include("systems/jumps/jumpsystem.jl")
134133

135134
include("systems/nonlinear/nonlinearsystem.jl")
135+
include("systems/nonlinear/modelingtoolkitize.jl")
136136

137137
include("systems/optimization/optimizationsystem.jl")
138138

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
$(TYPEDSIGNATURES)
3+
4+
Generate `NonlinearSystem`, dependent variables, and parameters from an `NonlinearProblem`.
5+
"""
6+
function modelingtoolkitize(prob::NonlinearProblem; kwargs...)
7+
p = prob.p
8+
has_p = !(p isa Union{DiffEqBase.NullParameters,Nothing})
9+
10+
_vars = reshape([variable(:x, i) for i in eachindex(prob.u0)], size(prob.u0))
11+
12+
vars = prob.u0 isa Number ? _vars : ArrayInterface.restructure(prob.u0, _vars)
13+
params = if has_p
14+
_params = define_params(p)
15+
p isa Number ? _params[1] : (p isa Tuple || p isa NamedTuple ? _params : ArrayInterface.restructure(p, _params))
16+
else
17+
[]
18+
end
19+
20+
if DiffEqBase.isinplace(prob)
21+
rhs = ArrayInterface.restructure(prob.u0, similar(vars, Num))
22+
prob.f(rhs, vars, params)
23+
else
24+
rhs = prob.f(vars, params)
25+
end
26+
out_def = prob.f(prob.u0, prob.p)
27+
eqs = vcat([0.0 ~ rhs[i] for i in 1:length(out_def)]...)
28+
29+
sts = vec(collect(vars))
30+
31+
params = if params isa Number || (params isa Array && ndims(params) == 0)
32+
[params[1]]
33+
else
34+
vec(collect(params))
35+
end
36+
default_u0 = Dict(sts .=> vec(collect(prob.u0)))
37+
default_p = has_p ? Dict(params .=> vec(collect(prob.p))) : Dict()
38+
39+
de = NonlinearSystem(
40+
eqs, sts, params,
41+
defaults=merge(default_u0, default_p);
42+
name=gensym(:MTKizedNonlinProb),
43+
kwargs...
44+
)
45+
46+
de
47+
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,23 @@ function generate_jacobian(sys::NonlinearSystem, vs = states(sys), ps = paramete
137137
return build_function(jac, vs, ps; kwargs...)
138138
end
139139

140+
function calculate_hessian(sys::NonlinearSystem; sparse=false, simplify=false)
141+
rhs = [eq.rhs for eq equations(sys)]
142+
vals = [dv for dv in states(sys)]
143+
if sparse
144+
hess = [sparsehessian(rhs[i], vals, simplify=simplify) for i in 1:length(rhs)]
145+
else
146+
hess = [hessian(rhs[i], vals, simplify=simplify) for i in 1:length(rhs)]
147+
end
148+
return hess
149+
end
150+
151+
function generate_hessian(sys::NonlinearSystem, vs = states(sys), ps = parameters(sys);
152+
sparse = false, simplify=false, kwargs...)
153+
hess = calculate_hessian(sys,sparse=sparse, simplify=simplify)
154+
return build_function(hess, vs, ps; kwargs...)
155+
end
156+
140157
function generate_function(sys::NonlinearSystem, dvs = states(sys), ps = parameters(sys); kwargs...)
141158
rhss = [deq.rhs for deq equations(sys)]
142159
pre, sol_states = get_substitutions_and_solved_states(sys)

src/systems/optimization/optimizationsystem.jl

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,11 @@ function OptimizationSystem(op, states, ps;
7878
process_variables!(var_to_name, defaults, states)
7979
process_variables!(var_to_name, defaults, ps)
8080
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
81-
8281
OptimizationSystem(
8382
value(op), states, ps, var_to_name,
8483
observed,
8584
equality_constraints, inequality_constraints,
86-
name, systems, defaults, checks = checks
85+
name, systems, defaults; checks = checks
8786
)
8887
end
8988

@@ -122,8 +121,6 @@ namespace_expr(sys::OptimizationSystem) = namespace_expr(get_op(sys), sys)
122121

123122
hessian_sparsity(sys::OptimizationSystem) = hessian_sparsity(get_op(sys), states(sys))
124123

125-
struct AutoModelingToolkit <: DiffEqBase.AbstractADType end
126-
127124
DiffEqBase.OptimizationProblem(sys::OptimizationSystem,args...;kwargs...) =
128125
DiffEqBase.OptimizationProblem{true}(sys::OptimizationSystem,args...;kwargs...)
129126

@@ -175,7 +172,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
175172
_hess = nothing
176173
end
177174

178-
_f = DiffEqBase.OptimizationFunction{iip,AutoModelingToolkit,typeof(f),typeof(_grad),typeof(_hess),Nothing,Nothing,Nothing,Nothing}(f,AutoModelingToolkit(),_grad,_hess,nothing,nothing,nothing,nothing)
175+
_f = DiffEqBase.OptimizationFunction{iip,SciMLBase.NoAD,typeof(f),typeof(_grad),typeof(_hess),Nothing,Nothing,Nothing,Nothing}(f, SciMLBase.NoAD(), _grad, _hess, nothing, nothing, nothing, nothing)
179176

180177
defs = defaults(sys)
181178
defs = mergedefaults(defs,parammap,ps)
@@ -253,21 +250,7 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
253250
hess = $_hess
254251
lb = $lb
255252
ub = $ub
256-
_f = OptimizationFunction{$iip,typeof(f),typeof(grad),typeof(hess),Nothing,Nothing,Nothing,Nothing}(f,grad,hess,nothing,AutoModelingToolkit(),nothing,nothing,nothing,0)
253+
_f = OptimizationFunction{$iip,typeof(f),typeof(grad),typeof(hess),SciMLBase.NoAD,Nothing,Nothing,Nothing}(f,grad,hess,nothing,SciMLBase.NoAD(),nothing,nothing,nothing,0)
257254
OptimizationProblem{$iip}(_f,u0,p;lb=lb,ub=ub,kwargs...)
258255
end
259256
end
260-
261-
function DiffEqBase.OptimizationFunction{iip}(f, ::AutoModelingToolkit, x, p = DiffEqBase.NullParameters();
262-
grad=false, hess=false, cons = nothing, cons_j = nothing, cons_h = nothing,
263-
num_cons = 0, chunksize = 1, hv = nothing) where iip
264-
265-
sys = modelingtoolkitize(OptimizationProblem(f,x,p))
266-
u0map = states(sys) .=> x
267-
if p == DiffEqBase.NullParameters()
268-
parammap = DiffEqBase.NullParameters()
269-
else
270-
parammap = parameters(sys) .=> p
271-
end
272-
OptimizationProblem(sys,u0map,parammap,grad=grad,hess=hess).f
273-
end

test/controlsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, GalacticOptim, Optim, GalacticOptimJL
1+
using ModelingToolkit, GalacticOptim, GalacticOptimJL
22

33
@variables t x(t) v(t) u(t)
44
@parameters p[1:2]

test/modelingtoolkitize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using OrdinaryDiffEq, ModelingToolkit, Test
2-
using GalacticOptim, Optim, RecursiveArrayTools, GalacticOptimJL
2+
using GalacticOptim, RecursiveArrayTools, GalacticOptimJL
33

44
N = 32
55
const xyd_brusselator = range(0,stop=1,length=N)

test/optimizationsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, SparseArrays, Test, GalacticOptim, Optim, GalacticOptimJL
1+
using ModelingToolkit, SparseArrays, Test, GalacticOptim, GalacticOptimJL
22

33
@variables x y
44
@parameters a b
@@ -51,9 +51,9 @@ rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
5151
x0 = zeros(2)
5252
_p = [1.0, 100.0]
5353

54-
f = OptimizationFunction(rosenbrock,ModelingToolkit.AutoModelingToolkit(),x0,_p,grad=true,hess=true)
54+
f = OptimizationFunction(rosenbrock,GalacticOptim.AutoModelingToolkit())
5555
prob = OptimizationProblem(f,x0,_p)
56-
sol = solve(prob,Optim.Newton())
56+
sol = solve(prob,Newton())
5757

5858
# issue #819
5959
@testset "Combined system name collisions" begin

0 commit comments

Comments
 (0)