Commit 33e0000

Merge pull request #585 from SciML/opfadups
[WIP] Enzyme and sparse updates
2 parents c3ba7aa + c1ccddf · commit 33e0000

11 files changed: +743 −492 lines

.gitignore

Lines changed: 2 additions & 1 deletion

@@ -1,4 +1,5 @@
 .DS_Store
 /Manifest.toml
 /dev/
-/docs/build/
+/docs/build/
+.vscode

Project.toml

Lines changed: 3 additions & 3 deletions

@@ -17,7 +17,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
-Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
 TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
 
 [weakdeps]
@@ -27,6 +26,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
+Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -36,8 +36,7 @@ OptimizationFiniteDiffExt = "FiniteDiff"
 OptimizationForwardDiffExt = "ForwardDiff"
 OptimizationMTKExt = "ModelingToolkit"
 OptimizationReverseDiffExt = "ReverseDiff"
-OptimizationSparseFiniteDiffExt = ["SparseDiffTools", "FiniteDiff"]
-OptimizationSparseForwardDiffExt = ["SparseDiffTools", "ForwardDiff"]
+OptimizationSparseDiffExt = ["SparseDiffTools", "Symbolics", "ReverseDiff"]
 OptimizationTrackerExt = "Tracker"
 OptimizationZygoteExt = "Zygote"
@@ -58,3 +57,4 @@ julia = "1.6"
 
 [extras]
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
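
Taken together, these Project.toml changes demote Symbolics from a hard dependency to a weak one and merge the two sparse AD extensions (OptimizationSparseFiniteDiffExt, OptimizationSparseForwardDiffExt) into a single OptimizationSparseDiffExt triggered by SparseDiffTools, Symbolics, and ReverseDiff together. A minimal sketch of what that means for a user session on Julia >= 1.9 (the package and extension names come from the diff; the session itself is illustrative):

using Optimization
# Loading all three weak dependencies triggers OptimizationSparseDiffExt.
using SparseDiffTools, Symbolics, ReverseDiff

# Base.get_extension (Julia 1.9+) returns the extension module once its
# trigger packages have loaded, or `nothing` if any of them is missing.
ext = Base.get_extension(Optimization, :OptimizationSparseDiffExt)
@assert ext !== nothing

On Julia 1.6–1.8 the same code paths are wired up through Requires.jl instead, which is why Requires remains a hard dependency above.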

ext/OptimizationForwardDiffExt.jl

Lines changed: 2 additions & 2 deletions

@@ -107,8 +107,8 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},
 
     if f.hess === nothing
         hesscfg = ForwardDiff.HessianConfig(_f, cache.u0, ForwardDiff.Chunk{chunksize}())
-        hess = (res, θ, args...) -> ForwardDiff.hessian!(res, x -> _f(x, args...), θ,
-                                                         hesscfg, Val{false}())
+        hess = (res, θ, args...) -> (ForwardDiff.hessian!(res, x -> _f(x, args...), θ,
+                                                          hesscfg, Val{false}()))
     else
         hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
     end
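
The change here is essentially cosmetic (the hessian! call gains wrapping parentheses); the functional pattern is building one ForwardDiff.HessianConfig up front and reusing it for every Hessian evaluation. A standalone sketch of that pattern, with an objective and chunk size that are illustrative rather than taken from the repository:

using ForwardDiff

rosen(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2  # illustrative objective
u0 = zeros(2)
H = zeros(2, 2)

# Build the config once, as instantiate_function does, then reuse it on
# every call; Val{false}() disables ForwardDiff's perturbation-tag checking.
hesscfg = ForwardDiff.HessianConfig(rosen, u0, ForwardDiff.Chunk{2}())
ForwardDiff.hessian!(H, rosen, u0, hesscfg, Val{false}())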

ext/OptimizationReverseDiffExt.jl

Lines changed: 5 additions & 13 deletions

@@ -3,6 +3,7 @@ module OptimizationReverseDiffExt
 import Optimization
 import Optimization.SciMLBase: OptimizationFunction
 import Optimization.ADTypes: AutoReverseDiff
+# using SparseDiffTools, Symbolics
 isdefined(Base, :get_extension) ? (using ReverseDiff, ReverseDiff.ForwardDiff) :
 (using ..ReverseDiff, ..ReverseDiff.ForwardDiff)
 
@@ -20,9 +21,7 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
 
     if f.hess === nothing
         hess = function (res, θ, args...)
-            res .= ForwardDiff.jacobian(θ) do θ
-                ReverseDiff.gradient(x -> _f(x, args...), θ)
-            end
+            ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
         end
     else
         hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
@@ -59,9 +58,7 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
         fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
         cons_h = function (res, θ)
             for i in 1:num_cons
-                res[i] .= ForwardDiff.jacobian(θ) do θ
-                    ReverseDiff.gradient(fncs[i], θ)
-                end
+                ReverseDiff.hessian!(res[i], fncs[i], θ)
             end
         end
     else
@@ -86,17 +83,14 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
     _f = (θ, args...) -> first(f.f(θ, cache.p, args...))
 
     if f.grad === nothing
-        cfg = ReverseDiff.GradientConfig(cache.u0)
         grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ)
     else
         grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
     end
 
     if f.hess === nothing
         hess = function (res, θ, args...)
-            res .= ForwardDiff.jacobian(θ) do θ
-                ReverseDiff.gradient(x -> _f(x, args...), θ)
-            end
+            ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
         end
     else
         hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
@@ -133,9 +127,7 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
         fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
         cons_h = function (res, θ)
             for i in 1:num_cons
-                res[i] .= ForwardDiff.jacobian(θ) do θ
-                    ReverseDiff.gradient(fncs[i], θ)
-                end
+                ReverseDiff.hessian!(res[i], fncs[i], θ)
             end
         end
     else
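
Every hunk in this file makes the same substitution: the hand-rolled forward-over-reverse Hessian (a ForwardDiff Jacobian of a ReverseDiff gradient) becomes a single ReverseDiff.hessian! call that fills the result array in place, and the now-unused GradientConfig is dropped. A sketch of the before/after equivalence on an illustrative objective (not from the repository):

using ReverseDiff, ReverseDiff.ForwardDiff

f(x) = sum(abs2, x) + x[1] * x[2]  # illustrative objective
θ = [1.0, 2.0]

# Old pattern: forward-mode Jacobian of the reverse-mode gradient.
H_old = ForwardDiff.jacobian(θ) do θ
    ReverseDiff.gradient(f, θ)
end

# New pattern: ReverseDiff's built-in in-place Hessian.
H_new = zeros(2, 2)
ReverseDiff.hessian!(H_new, f, θ)

@assert H_old ≈ H_new  # both are [2.0 1.0; 1.0 2.0]

Both formulations fill the result with the same dense Hessian; the hessian! form is shorter and delegates the nested differentiation to ReverseDiff.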
