Skip to content

Commit a989d96

Browse files
Merge pull request #241 from SciML/Vaibhavdixit02-patch-1
Add `obj_sparse` and `cons_sparse` fields to `AutoModelingToolkit` and pass to derivative functions
2 parents 2bd1134 + 96390c9 commit a989d96

File tree

3 files changed

+26
-6
lines changed

3 files changed

+26
-6
lines changed

src/function/mtk.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1-
struct AutoModelingToolkit <: AbstractADType end
1+
struct AutoModelingToolkit <: AbstractADType
2+
obj_sparse::Bool
3+
cons_sparse::Bool
4+
end
5+
6+
AutoModelingToolkit() = AutoModelingToolkit(false, false)
27

3-
function instantiate_function(f, x, ::AutoModelingToolkit, p, num_cons=0)
8+
function instantiate_function(f, x, ad::AutoModelingToolkit, p, num_cons=0)
49
p = isnothing(p) ? SciMLBase.NullParameters() : p
510
sys = ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, x, p))
611

@@ -12,8 +17,8 @@ function instantiate_function(f, x, ::AutoModelingToolkit, p, num_cons=0)
1217
end
1318

1419
if f.hess === nothing
15-
hess_oop, hess_iip = ModelingToolkit.generate_hessian(sys, expression=Val{false})
16-
hess(J, u) = (hess_iip(J, u, p); J)
20+
hess_oop, hess_iip = ModelingToolkit.generate_hessian(sys, expression=Val{false}, sparse = ad.obj_sparse)
21+
hess(H, u) = (hess_iip(H, u, p); H)
1722
else
1823
hess = f.hess
1924
end
@@ -36,17 +41,17 @@ function instantiate_function(f, x, ::AutoModelingToolkit, p, num_cons=0)
3641
end
3742

3843
if f.cons !== nothing && f.cons_j === nothing
44+
jac_oop, jac_iip = ModelingToolkit.generate_jacobian(cons_sys, expression=Val{false}, sparse=ad.cons_sparse)
3945
cons_j = function (J, θ)
40-
jac_oop, jac_iip = ModelingToolkit.generate_jacobian(cons_sys, expression=Val{false})
4146
jac_iip(J, θ, p)
4247
end
4348
else
4449
cons_j = f.cons_j
4550
end
4651

4752
if f.cons !== nothing && f.cons_h === nothing
53+
cons_hess_oop, cons_hess_iip = ModelingToolkit.generate_hessian(cons_sys, expression=Val{false}, sparse=ad.cons_sparse)
4854
cons_h = function (res, θ)
49-
cons_hess_oop, cons_hess_iip = ModelingToolkit.generate_hessian(cons_sys, expression=Val{false})
5055
cons_hess_iip(res, θ, p)
5156
end
5257
else

test/ADtests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,20 @@ H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)]
5757
optprob.cons_h(H3, x0)
5858
@test H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]]
5959

60+
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoModelingToolkit(true, true), cons=con2_c)
61+
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoModelingToolkit(true, true), nothing, 2)
62+
using SparseArrays
63+
sH = sparse([1, 1, 2, 2], [1, 2, 1, 2], zeros(4))
64+
optprob.hess(sH, x0)
65+
@test sH == H2
66+
@test optprob.cons(x0) == [0.0, 0.0]
67+
sJ = sparse([1, 1, 2, 2], [1, 2, 1, 2], zeros(4))
68+
optprob.cons_j(sJ, [5.0, 3.0])
69+
@test all(isapprox(sJ, [10.0 6.0; -0.149013 -0.958924]; rtol=1e-3))
70+
sH3 = [sparse([1,2], [1, 2], zeros(2)), sparse([1, 1, 2], [1, 2, 1], zeros(3))]
71+
optprob.cons_h(sH3, x0)
72+
@test Array.(sH3) == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]]
73+
6074
optf = OptimizationFunction(rosenbrock, GalacticOptim.AutoForwardDiff())
6175
optprob = GalacticOptim.instantiate_function(optf, x0, GalacticOptim.AutoForwardDiff(), nothing)
6276
optprob.grad(G2, x0)

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1414
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
15+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1516
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1617
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1718
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

0 commit comments

Comments
 (0)