Skip to content

Commit 61556cc

Browse files
Merge pull request #237 from SciML/automtkoptfun
Add mtk.jl for `AutoModelingToolkit` as AD backend support
2 parents ed38bee + d7edbad commit 61556cc

File tree

6 files changed

+133
-35
lines changed

6 files changed

+133
-35
lines changed

lib/GalacticOptimJL/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,17 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
99
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1010

1111
[compat]
12-
julia = "1"
1312
GalacticOptim = "3"
1413
Optim = "1"
1514
Reexport = "1.2"
15+
julia = "1"
1616

1717
[extras]
18+
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
1819
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1920
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2021
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2122
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2223

2324
[targets]
24-
test = ["ForwardDiff", "Random", "Test", "Zygote"]
25+
test = ["ForwardDiff", "ModelingToolkit", "Random", "Test", "Zygote"]

lib/GalacticOptimJL/test/runtests.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using GalacticOptimJL, GalacticOptimJL.Optim, GalacticOptim, ForwardDiff, Zygote, Random
1+
using GalacticOptimJL, GalacticOptimJL.Optim, GalacticOptim, ForwardDiff, Zygote, Random, ModelingToolkit
22
using Test
33

44
@testset "GalacticOptimJL.jl" begin
@@ -23,6 +23,7 @@ using Test
2323

2424
cons = (x, p) -> [x[1]^2 + x[2]^2]
2525
optprob = OptimizationFunction(rosenbrock, GalacticOptim.AutoForwardDiff(); cons=cons)
26+
optprob = OptimizationFunction(rosenbrock, GalacticOptim.AutoModelingToolkit(); cons=cons)
2627

2728
prob = OptimizationProblem(optprob, x0, _p)
2829

@@ -90,4 +91,9 @@ using Test
9091
prob = OptimizationProblem(optprob, x0, _p; sense=GalacticOptim.MaxSense)
9192
sol = solve(prob, BFGS())
9293
@test 10 * sol.minimum < l1
94+
95+
optprob = OptimizationFunction(rosenbrock, GalacticOptim.AutoModelingToolkit())
96+
prob = OptimizationProblem(optprob, x0, _p)
97+
sol = solve(prob, Optim.BFGS())
98+
@test 10 * sol.minimum < l1
9399
end

src/GalacticOptim.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ include("function/function.jl")
2121

2222
function __init__()
2323
# AD backends
24-
@require FiniteDiff="6a86dc24-6348-571c-b903-95158fe2bd41" include("function/finitediff.jl")
25-
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("function/forwarddiff.jl")
26-
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("function/reversediff.jl")
27-
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("function/tracker.jl")
28-
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("function/zygote.jl")
24+
@require FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" include("function/finitediff.jl")
25+
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("function/forwarddiff.jl")
26+
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("function/reversediff.jl")
27+
@require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("function/tracker.jl")
28+
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("function/zygote.jl")
29+
@require ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" include("function/mtk.jl")
2930
end
3031

3132
export solve

src/function/forwarddiff.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,8 @@ function instantiate_function(f::OptimizationFunction{true}, x, ::AutoForwardDif
4343

4444
if f.cons === nothing
4545
cons = nothing
46-
cons! = nothing
4746
else
4847
cons = θ -> f.cons(θ,p)
49-
cons! = (res, θ) -> (res .= f.cons(θ,p); res)
5048
end
5149

5250
if cons !== nothing && f.cons_j === nothing

src/function/mtk.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
struct AutoModelingToolkit <: AbstractADType end
2+
3+
function instantiate_function(f, x, ::AutoModelingToolkit, p, num_cons=0)
4+
p = isnothing(p) ? SciMLBase.NullParameters() : p
5+
sys = ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, x, p))
6+
7+
if f.grad === nothing
8+
grad_oop, grad_iip = ModelingToolkit.generate_gradient(sys, expression=Val{false})
9+
grad(J, u) = (grad_iip(J, u, p); J)
10+
else
11+
grad = f.grad
12+
end
13+
14+
if f.hess === nothing
15+
hess_oop, hess_iip = ModelingToolkit.generate_hessian(sys, expression=Val{false})
16+
hess(J, u) = (hess_iip(J, u, p); J)
17+
else
18+
hess = f.hess
19+
end
20+
21+
if f.hv === nothing
22+
hv = function (H, θ, v, args...)
23+
res = ArrayInterface.zeromatrix(θ)
24+
hess(res, θ, args...)
25+
H .= res * v
26+
end
27+
else
28+
hv = f.hv
29+
end
30+
31+
if f.cons === nothing
32+
cons = nothing
33+
else
34+
cons = (θ) -> f.cons(θ, p)
35+
cons_sys = ModelingToolkit.modelingtoolkitize(OptimizationProblem(f.cons, x, p); checks=false)
36+
end
37+
38+
if f.cons !== nothing && f.cons_j === nothing
39+
cons_j = function (J, θ)
40+
jac_oop, jac_iip = ModelingToolkit.generate_jacobian(cons_sys, expression=Val{false})
41+
jac_iip(J, θ, p)
42+
end
43+
else
44+
cons_j = f.cons_j
45+
end
46+
47+
if f.cons !== nothing && f.cons_h === nothing
48+
cons_h = function (res, θ)
49+
for i in 1:num_cons
50+
cons_sys_i = ModelingToolkit.modelingtoolkitize(OptimizationProblem((args...) -> f.cons(args...)[i], x, p); checks=false)
51+
cons_hess_oop, cons_hess_iip = ModelingToolkit.generate_hessian(cons_sys_i, expression=Val{false})
52+
cons_hess_iip(res[i], θ, p)
53+
end
54+
end
55+
else
56+
cons_h = f.cons_h
57+
end
58+
59+
return OptimizationFunction{true,AutoModelingToolkit,typeof(f.f),typeof(grad),typeof(hess),typeof(hv),typeof(cons),typeof(cons_j),typeof(cons_h)}(f.f, AutoModelingToolkit(), grad, hess, hv, cons, cons_j, cons_h)
60+
end

test/ADtests.jl

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using GalacticOptim, GalacticOptimJL, GalacticFlux, Test
22
using ForwardDiff, Zygote, ReverseDiff, FiniteDiff, Tracker
3-
3+
using ModelingToolkit
44
x0 = zeros(2)
5-
rosenbrock(x, p=nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
5+
rosenbrock(x, p=nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
66
l1 = rosenbrock(x0)
77

88
function g!(G, x)
@@ -17,16 +17,48 @@ function h!(H, x)
1717
H[2, 2] = 200.0
1818
end
1919

20-
G1 = Array{Float64}(undef,2)
21-
G2 = Array{Float64}(undef,2)
20+
G1 = Array{Float64}(undef, 2)
21+
G2 = Array{Float64}(undef, 2)
2222
H1 = Array{Float64}(undef, 2, 2)
2323
H2 = Array{Float64}(undef, 2, 2)
2424

2525
g!(G1, x0)
2626
h!(H1, x0)
2727

28+
cons = (x, p) -> [x[1]^2 + x[2]^2]
29+
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoModelingToolkit(), cons=cons)
30+
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoModelingToolkit(), nothing, 1)
31+
optprob.grad(G2, x0)
32+
@test G1 == G2
33+
optprob.hess(H2, x0)
34+
@test H1 == H2
35+
@test optprob.cons(x0) == [0.0]
36+
J = Array{Float64}(undef, 2)
37+
optprob.cons_j(J, [5.0, 3.0])
38+
@test J == [10.0, 6.0]
39+
H3 = [Array{Float64}(undef, 2, 2)]
40+
optprob.cons_h(H3, x0)
41+
@test H3 == [[2.0 0.0; 0.0 2.0]]
42+
43+
function con2_c(x, p)
44+
[x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]]
45+
end
46+
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoModelingToolkit(), cons=con2_c)
47+
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoModelingToolkit(), nothing, 2)
48+
optprob.grad(G2, x0)
49+
@test G1 == G2
50+
optprob.hess(H2, x0)
51+
@test H1 == H2
52+
@test optprob.cons(x0) == [0.0, 0.0]
53+
J = Array{Float64}(undef, 2, 2)
54+
optprob.cons_j(J, [5.0, 3.0])
55+
@test all(isapprox(J, [10.0 6.0; -0.149013 -0.958924]; rtol=1e-3))
56+
H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)]
57+
optprob.cons_h(H3, x0)
58+
@test H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]]
59+
2860
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoForwardDiff())
29-
optprob = GalacticOptim.instantiate_function(optf,x0,GalacticOptim.AutoForwardDiff(),nothing)
61+
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoForwardDiff(), nothing)
3062
optprob.grad(G2, x0)
3163
@test G1 == G2
3264
optprob.hess(H2, x0)
@@ -35,16 +67,16 @@ optprob.hess(H2, x0)
3567
prob = OptimizationProblem(optprob, x0)
3668

3769
sol = solve(prob, Optim.BFGS())
38-
@test 10*sol.minimum < l1
70+
@test 10 * sol.minimum < l1
3971

4072
sol = solve(prob, Optim.Newton())
41-
@test 10*sol.minimum < l1
73+
@test 10 * sol.minimum < l1
4274

4375
sol = solve(prob, Optim.KrylovTrustRegion())
44-
@test 10*sol.minimum < l1
76+
@test 10 * sol.minimum < l1
4577

4678
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoZygote())
47-
optprob = GalacticOptim.instantiate_function(optf,x0,GalacticOptim.AutoZygote(),nothing)
79+
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoZygote(), nothing)
4880
optprob.grad(G2, x0)
4981
@test G1 == G2
5082
optprob.hess(H2, x0)
@@ -53,33 +85,33 @@ optprob.hess(H2, x0)
5385
prob = OptimizationProblem(optprob, x0)
5486

5587
sol = solve(prob, Optim.BFGS())
56-
@test 10*sol.minimum < l1
88+
@test 10 * sol.minimum < l1
5789

5890
sol = solve(prob, Optim.Newton())
59-
@test 10*sol.minimum < l1
91+
@test 10 * sol.minimum < l1
6092

6193
sol = solve(prob, Optim.KrylovTrustRegion())
62-
@test 10*sol.minimum < l1
94+
@test 10 * sol.minimum < l1
6395

6496
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoReverseDiff())
65-
optprob = GalacticOptim.instantiate_function(optf,x0,GalacticOptim.AutoReverseDiff(),nothing)
97+
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoReverseDiff(), nothing)
6698
optprob.grad(G2, x0)
6799
@test G1 == G2
68100
optprob.hess(H2, x0)
69101
@test H1 == H2
70102

71103
prob = OptimizationProblem(optprob, x0)
72104
sol = solve(prob, Optim.BFGS())
73-
@test 10*sol.minimum < l1
105+
@test 10 * sol.minimum < l1
74106

75107
sol = solve(prob, Optim.Newton())
76-
@test 10*sol.minimum < l1
108+
@test 10 * sol.minimum < l1
77109

78110
sol = solve(prob, Optim.KrylovTrustRegion())
79-
@test 10*sol.minimum < l1
111+
@test 10 * sol.minimum < l1
80112

81113
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoTracker())
82-
optprob = GalacticOptim.instantiate_function(optf,x0,GalacticOptim.AutoTracker(),nothing)
114+
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoTracker(), nothing)
83115
optprob.grad(G2, x0)
84116
@test G1 == G2
85117
@test_throws ErrorException optprob.hess(H2, x0)
@@ -88,26 +120,26 @@ optprob.grad(G2, x0)
88120
prob = OptimizationProblem(optprob, x0)
89121

90122
sol = solve(prob, Optim.BFGS())
91-
@test 10*sol.minimum < l1
123+
@test 10 * sol.minimum < l1
92124

93125
@test_throws ErrorException solve(prob, Newton())
94126

95127
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoFiniteDiff())
96-
optprob = GalacticOptim.instantiate_function(optf,x0,GalacticOptim.AutoFiniteDiff(),nothing)
128+
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoFiniteDiff(), nothing)
97129
optprob.grad(G2, x0)
98-
@test G1 G2 rtol=1e-6
130+
@test G1 G2 rtol = 1e-6
99131
optprob.hess(H2, x0)
100-
@test H1 H2 rtol=1e-6
132+
@test H1 H2 rtol = 1e-6
101133

102134
prob = OptimizationProblem(optprob, x0)
103135
sol = solve(prob, Optim.BFGS())
104-
@test 10*sol.minimum < l1
136+
@test 10 * sol.minimum < l1
105137

106138
sol = solve(prob, Optim.Newton())
107-
@test 10*sol.minimum < l1
139+
@test 10 * sol.minimum < l1
108140

109141
sol = solve(prob, Optim.KrylovTrustRegion())
110142
@test sol.minimum < l1 #the loss doesn't go below 5e-1 here
111143

112-
sol = solve(prob, Flux.ADAM(0.1), maxiters = 1000)
113-
@test 10*sol.minimum < l1
144+
sol = solve(prob, Flux.ADAM(0.1), maxiters=1000)
145+
@test 10 * sol.minimum < l1

0 commit comments

Comments
 (0)