Skip to content

Commit c034718

Browse files
Merge pull request #432 from SciML/mtkize
Switch to using `modelingtoolkitize` inside `AutoModelingToolkit`
2 parents ddb0279 + 3cc8af7 commit c034718

File tree

2 files changed

+24
-88
lines changed

2 files changed

+24
-88
lines changed

lib/OptimizationMOI/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ version = "0.1.7"
66
[deps]
77
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
88
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
9-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
109
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
10+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1111

1212
[compat]
1313
MathOptInterface = "1"

src/function/mtk.jl

Lines changed: 23 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -7,101 +7,37 @@ AutoModelingToolkit() = AutoModelingToolkit(false, false)
77

88
function instantiate_function(f, x, adtype::AutoModelingToolkit, p, num_cons = 0)
99
p = isnothing(p) ? SciMLBase.NullParameters() : p
10-
sys = ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, x, p))
1110

12-
hess_prototype, cons_jac_prototype, cons_hess_prototype = nothing, nothing, nothing
11+
sys = ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, x, p;
12+
lcons = fill(0.0,
13+
num_cons),
14+
ucons = fill(0.0,
15+
num_cons)))
16+
sys = ModelingToolkit.structural_simplify(sys)
17+
f = OptimizationProblem(sys, x, p, grad = true, hess = true,
18+
sparse = adtype.obj_sparse, cons_j = true, cons_h = true,
19+
cons_sparse = adtype.cons_sparse).f
1320

14-
if f.grad === nothing
15-
grad_oop, grad_iip = ModelingToolkit.generate_gradient(sys, expression = Val{false})
16-
grad(J, u) = (grad_iip(J, u, p); J)
17-
else
18-
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
19-
end
20-
21-
if f.hess === nothing
22-
hess_oop, hess_iip = ModelingToolkit.generate_hessian(sys, expression = Val{false},
23-
sparse = adtype.obj_sparse)
24-
hess(H, u) = (hess_iip(H, u, p); H)
25-
else
26-
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
27-
end
28-
29-
if f.hv === nothing
30-
hv = function (H, θ, v, args...)
31-
res = adtype.obj_sparse ? hess_prototype : ArrayInterfaceCore.zeromatrix(θ)
32-
hess(res, θ, args...)
33-
H .= res * v
34-
end
35-
else
36-
hv = f.hv
37-
end
38-
39-
expr = symbolify(ModelingToolkit.Symbolics.toexpr(ModelingToolkit.equations(sys)))
40-
pairs_arr = p isa SciMLBase.NullParameters ?
41-
[Symbol(_s) => Expr(:ref, :x, i) for (i, _s) in enumerate(sys.states)] :
42-
[
43-
[Symbol(_s) => Expr(:ref, :x, i) for (i, _s) in enumerate(sys.states)]...,
44-
[Symbol(_p) => p[i] for (i, _p) in enumerate(sys.ps)]...,
45-
]
46-
rep_pars_vals!(expr, pairs_arr)
21+
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
4722

48-
if f.cons === nothing
49-
cons = nothing
50-
cons_exprs = nothing
51-
else
52-
cons = (res, θ) -> f.cons(res, θ, p)
53-
cons_oop = (x, p) -> (_res = zeros(eltype(x), num_cons); f.cons(_res, x, p); _res)
23+
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
5424

55-
cons_sys = ModelingToolkit.modelingtoolkitize(NonlinearProblem(cons_oop, x, p))
56-
cons_eqs = ModelingToolkit.equations(cons_sys)
57-
cons_exprs = map(cons_eqs) do cons_eq
58-
e = symbolify(ModelingToolkit.Symbolics.toexpr(cons_eq))
59-
rep_pars_vals!(e, pairs_arr)
60-
return Expr(:call, :(==), e.args[3], :0)
61-
end
25+
hv = function (H, θ, v, args...)
26+
res = adtype.obj_sparse ? (eltype(θ)).(f.hess_prototype) : ArrayInterfaceCore.zeromatrix(θ)
27+
hess(res, θ, args...)
28+
H .= res * v
6229
end
6330

64-
if f.cons !== nothing && f.cons_j === nothing
65-
jac_oop, jac_iip = ModelingToolkit.generate_jacobian(cons_sys,
66-
expression = Val{false},
67-
sparse = adtype.cons_sparse)
68-
cons_j = function (J, θ)
69-
jac_iip(J, θ, p)
70-
end
71-
else
72-
cons_j = (J, θ) -> f.cons_j(J, θ, p)
73-
end
74-
75-
if f.cons !== nothing && f.cons_h === nothing
76-
cons_hess_oop, cons_hess_iip = ModelingToolkit.generate_hessian(cons_sys,
77-
expression = Val{
78-
false
79-
},
80-
sparse = adtype.cons_sparse)
81-
cons_h = function (res, θ)
82-
cons_hess_iip(res, θ, p)
83-
end
84-
else
85-
cons_h = (res, θ) -> f.cons_h(res, θ, p)
86-
end
31+
cons = (res, θ) -> f.cons(res, θ, p)
8732

88-
if adtype.obj_sparse
89-
_hess_prototype = ModelingToolkit.hessian_sparsity(sys)
90-
hess_prototype = convert.(eltype(x), _hess_prototype)
91-
end
92-
93-
if adtype.cons_sparse
94-
_cons_jac_prototype = ModelingToolkit.jacobian_sparsity(cons_sys)
95-
cons_jac_prototype = convert.(eltype(x), _cons_jac_prototype)
96-
_cons_hess_prototype = ModelingToolkit.hessian_sparsity(cons_sys)
97-
cons_hess_prototype = [convert.(eltype(x), _cons_hess_prototype[i])
98-
for i in 1:num_cons]
99-
end
33+
cons_j = (J, θ) -> f.cons_j(J, θ, p)
10034

35+
cons_h = (res, θ) -> f.cons_h(res, θ, p)
10136
return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
10237
cons = cons, cons_j = cons_j, cons_h = cons_h,
103-
hess_prototype = hess_prototype,
104-
cons_jac_prototype = cons_jac_prototype,
105-
cons_hess_prototype = cons_hess_prototype,
106-
expr = expr, cons_expr = cons_exprs)
38+
hess_prototype = f.hess_prototype,
39+
cons_jac_prototype = f.cons_jac_prototype,
40+
cons_hess_prototype = f.cons_hess_prototype,
41+
expr = symbolify(f.expr),
42+
cons_expr = symbolify.(f.cons_expr))
10743
end

0 commit comments

Comments
 (0)