Skip to content

Commit 607c475

Browse files
Add mtk.jl for AutoModelingToolkit as AD backend support
1 parent ed38bee commit 607c475

File tree

5 files changed

+105
-33
lines changed

5 files changed

+105
-33
lines changed

lib/GalacticOptimJL/test/runtests.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using GalacticOptimJL, GalacticOptimJL.Optim, GalacticOptim, ForwardDiff, Zygote, Random
1+
using GalacticOptimJL, GalacticOptimJL.Optim, GalacticOptim, ForwardDiff, Zygote, Random, ModelingToolkit
22
using Test
33

44
@testset "GalacticOptimJL.jl" begin
@@ -90,4 +90,9 @@ using Test
9090
prob = OptimizationProblem(optprob, x0, _p; sense=GalacticOptim.MaxSense)
9191
sol = solve(prob, BFGS())
9292
@test 10 * sol.minimum < l1
93+
94+
optprob = OptimizationFunction(rosenbrock, GalacticOptim.AutoModelingToolkit())
95+
prob = OptimizationProblem(optprob, x0, _p)
96+
sol = solve(prob, Optim.BFGS())
97+
@test 10 * sol.minimum < l1
9398
end

src/GalacticOptim.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ include("function/function.jl")
2121

2222
function __init__()
2323
# AD backends
24-
@require FiniteDiff="6a86dc24-6348-571c-b903-95158fe2bd41" include("function/finitediff.jl")
25-
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("function/forwarddiff.jl")
26-
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("function/reversediff.jl")
27-
@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("function/tracker.jl")
28-
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("function/zygote.jl")
24+
@require FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" include("function/finitediff.jl")
25+
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("function/forwarddiff.jl")
26+
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("function/reversediff.jl")
27+
@require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("function/tracker.jl")
28+
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("function/zygote.jl")
29+
@require ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" include("function/mtk.jl")
2930
end
3031

3132
export solve

src/function/forwarddiff.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,8 @@ function instantiate_function(f::OptimizationFunction{true}, x, ::AutoForwardDif
4343

4444
if f.cons === nothing
4545
cons = nothing
46-
cons! = nothing
4746
else
4847
cons = θ -> f.cons(θ,p)
49-
cons! = (res, θ) -> (res .= f.cons(θ,p); res)
5048
end
5149

5250
if cons !== nothing && f.cons_j === nothing

src/function/mtk.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# AD-backend marker type: `instantiate_function` dispatches on it to build derivative
# callbacks through ModelingToolkit's code generation.
struct AutoModelingToolkit <: AbstractADType end
2+
3+
# Build the derivative callbacks of `f` with ModelingToolkit and return a fully populated
# `OptimizationFunction{true,AutoModelingToolkit,...}`.
#
# Arguments
# - `f`: the user's optimization function; any of `grad`, `hess`, `hv`, `cons_j`, `cons_h`
#   that are already set on it (not `nothing`) are reused unchanged.
# - `x`: point passed to `OptimizationProblem` when tracing with `modelingtoolkitize`.
# - `p`: parameters; closed over by every generated callback.
# - `num_cons`: number of constraint components, only read by the generated `cons_h`.
function instantiate_function(f, x, ::AutoModelingToolkit, p, num_cons=0)
    # Trace the objective once; gradient and Hessian code are both generated from `sys`.
    sys = ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, x, p))
    # Diagnostic only: goes through the logging system instead of printing to stdout.
    @debug "AutoModelingToolkit: traced objective" sys

    # Anonymous functions (rather than named local methods) so that both branches of each
    # `if` are plain assignments to the same variable.
    if f.grad === nothing
        _, grad_iip = ModelingToolkit.generate_gradient(sys, expression=Val{false})
        grad = (J, u) -> (grad_iip(J, u, p); J)
    else
        grad = f.grad
    end

    if f.hess === nothing
        _, hess_iip = ModelingToolkit.generate_hessian(sys, expression=Val{false})
        hess = (J, u) -> (hess_iip(J, u, p); J)
    else
        hess = f.hess
    end

    if f.hv === nothing
        # Hessian-vector product via the dense Hessian: fill `res`, then multiply by `v`.
        # NOTE(review): the generated `hess` above takes exactly `(J, u)`, so non-empty
        # `args...` only works with a user-supplied `f.hess` — confirm intended.
        hv = function (H, θ, v, args...)
            res = ArrayInterface.zeromatrix(θ)
            hess(res, θ, args...)
            H .= res * v
        end
    else
        hv = f.hv
    end

    if f.cons === nothing
        cons = nothing
    else
        cons = (θ) -> f.cons(θ, p)
        cons_sys = ModelingToolkit.modelingtoolkitize(
            OptimizationProblem(f.cons, x, p); checks=false)
    end

    if f.cons !== nothing && f.cons_j === nothing
        # Code generation stays inside the callback (as before), so it only runs — and can
        # only fail — when `cons_j` is actually called.
        cons_j = function (J, θ)
            _, jac_iip = ModelingToolkit.generate_jacobian(cons_sys, expression=Val{false})
            jac_iip(J, θ, p)
        end
    else
        cons_j = f.cons_j
    end

    if f.cons !== nothing && f.cons_h === nothing
        # One Hessian per constraint component, written into `res[i]`; each component is
        # traced and code-generated on every call.
        cons_h = function (res, θ)
            for i in 1:num_cons
                cons_sys_i = ModelingToolkit.modelingtoolkitize(
                    OptimizationProblem((args...) -> f.cons(args...)[i], x, p);
                    checks=false)
                _, cons_hess_iip = ModelingToolkit.generate_hessian(
                    cons_sys_i, expression=Val{false})
                cons_hess_iip(res[i], θ, p)
            end
        end
    else
        cons_h = f.cons_h
    end

    return OptimizationFunction{true,AutoModelingToolkit,typeof(f.f),typeof(grad),
                                typeof(hess),typeof(hv),typeof(cons),typeof(cons_j),
                                typeof(cons_h)}(f.f, AutoModelingToolkit(), grad, hess, hv,
                                                cons, cons_j, cons_h)
end

test/ADtests.jl

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using GalacticOptim, GalacticOptimJL, GalacticFlux, Test
22
using ForwardDiff, Zygote, ReverseDiff, FiniteDiff, Tracker
3-
3+
using ModelingToolkit
44
x0 = zeros(2)
5-
rosenbrock(x, p=nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
5+
rosenbrock(x, p=nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
66
l1 = rosenbrock(x0)
77

88
function g!(G, x)
@@ -17,16 +17,24 @@ function h!(H, x)
1717
H[2, 2] = 200.0
1818
end
1919

20-
G1 = Array{Float64}(undef,2)
21-
G2 = Array{Float64}(undef,2)
20+
G1 = Array{Float64}(undef, 2)
21+
G2 = Array{Float64}(undef, 2)
2222
H1 = Array{Float64}(undef, 2, 2)
2323
H2 = Array{Float64}(undef, 2, 2)
2424

2525
g!(G1, x0)
2626
h!(H1, x0)
2727

28+
cons = (x, p) -> [x[1]^2 + x[2]^2]
29+
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoModelingToolkit(), cons = cons)
30+
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoModelingToolkit(), nothing)
31+
optprob.grad(G2, x0)
32+
@test G1 == G2
33+
optprob.hess(H2, x0)
34+
@test H1 == H2
35+
2836
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoForwardDiff())
29-
optprob = GalacticOptim.instantiate_function(optf,x0,GalacticOptim.AutoForwardDiff(),nothing)
37+
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoForwardDiff(), nothing)
3038
optprob.grad(G2, x0)
3139
@test G1 == G2
3240
optprob.hess(H2, x0)
@@ -35,16 +43,16 @@ optprob.hess(H2, x0)
3543
prob = OptimizationProblem(optprob, x0)
3644

3745
sol = solve(prob, Optim.BFGS())
38-
@test 10*sol.minimum < l1
46+
@test 10 * sol.minimum < l1
3947

4048
sol = solve(prob, Optim.Newton())
41-
@test 10*sol.minimum < l1
49+
@test 10 * sol.minimum < l1
4250

4351
sol = solve(prob, Optim.KrylovTrustRegion())
44-
@test 10*sol.minimum < l1
52+
@test 10 * sol.minimum < l1
4553

4654
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoZygote())
47-
optprob = GalacticOptim.instantiate_function(optf,x0,GalacticOptim.AutoZygote(),nothing)
55+
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoZygote(), nothing)
4856
optprob.grad(G2, x0)
4957
@test G1 == G2
5058
optprob.hess(H2, x0)
@@ -53,33 +61,33 @@ optprob.hess(H2, x0)
5361
prob = OptimizationProblem(optprob, x0)
5462

5563
sol = solve(prob, Optim.BFGS())
56-
@test 10*sol.minimum < l1
64+
@test 10 * sol.minimum < l1
5765

5866
sol = solve(prob, Optim.Newton())
59-
@test 10*sol.minimum < l1
67+
@test 10 * sol.minimum < l1
6068

6169
sol = solve(prob, Optim.KrylovTrustRegion())
62-
@test 10*sol.minimum < l1
70+
@test 10 * sol.minimum < l1
6371

6472
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoReverseDiff())
65-
optprob = GalacticOptim.instantiate_function(optf,x0,GalacticOptim.AutoReverseDiff(),nothing)
73+
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoReverseDiff(), nothing)
6674
optprob.grad(G2, x0)
6775
@test G1 == G2
6876
optprob.hess(H2, x0)
6977
@test H1 == H2
7078

7179
prob = OptimizationProblem(optprob, x0)
7280
sol = solve(prob, Optim.BFGS())
73-
@test 10*sol.minimum < l1
81+
@test 10 * sol.minimum < l1
7482

7583
sol = solve(prob, Optim.Newton())
76-
@test 10*sol.minimum < l1
84+
@test 10 * sol.minimum < l1
7785

7886
sol = solve(prob, Optim.KrylovTrustRegion())
79-
@test 10*sol.minimum < l1
87+
@test 10 * sol.minimum < l1
8088

8189
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoTracker())
82-
optprob = GalacticOptim.instantiate_function(optf,x0,GalacticOptim.AutoTracker(),nothing)
90+
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoTracker(), nothing)
8391
optprob.grad(G2, x0)
8492
@test G1 == G2
8593
@test_throws ErrorException optprob.hess(H2, x0)
@@ -88,26 +96,26 @@ optprob.grad(G2, x0)
8896
prob = OptimizationProblem(optprob, x0)
8997

9098
sol = solve(prob, Optim.BFGS())
91-
@test 10*sol.minimum < l1
99+
@test 10 * sol.minimum < l1
92100

93101
@test_throws ErrorException solve(prob, Newton())
94102

95103
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoFiniteDiff())
96-
optprob = GalacticOptim.instantiate_function(optf,x0,GalacticOptim.AutoFiniteDiff(),nothing)
104+
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoFiniteDiff(), nothing)
97105
optprob.grad(G2, x0)
98-
@test G1 ≈ G2 rtol=1e-6
106+
@test G1 ≈ G2 rtol = 1e-6
99107
optprob.hess(H2, x0)
100-
@test H1 ≈ H2 rtol=1e-6
108+
@test H1 ≈ H2 rtol = 1e-6
101109

102110
prob = OptimizationProblem(optprob, x0)
103111
sol = solve(prob, Optim.BFGS())
104-
@test 10*sol.minimum < l1
112+
@test 10 * sol.minimum < l1
105113

106114
sol = solve(prob, Optim.Newton())
107-
@test 10*sol.minimum < l1
115+
@test 10 * sol.minimum < l1
108116

109117
sol = solve(prob, Optim.KrylovTrustRegion())
110118
@test sol.minimum < l1 #the loss doesn't go below 5e-1 here
111119

112-
sol = solve(prob, Flux.ADAM(0.1), maxiters = 1000)
113-
@test 10*sol.minimum < l1
120+
sol = solve(prob, Flux.ADAM(0.1), maxiters=1000)
121+
@test 10 * sol.minimum < l1

0 commit comments

Comments
 (0)