Commit 1167183

Merge pull request #590 from SciML/reversediff_tape_compilaton
Setup tape compilation for ReverseDiff

2 parents c3f0da8 + 0f6c2f1

7 files changed: +346 −63 lines

ext/OptimizationEnzymeExt.jl (12 additions, 10 deletions)
@@ -6,7 +6,7 @@ import Optimization.LinearAlgebra: I
 import Optimization.ADTypes: AutoEnzyme
 isdefined(Base, :get_extension) ? (using Enzyme) : (using ..Enzyme)
 
-@inline function firstapply(f, θ, p, args...)
+@inline function firstapply(f::F, θ, p, args...) where F
     res = f(θ, p, args...)
     if isa(res, AbstractFloat)
         res
@@ -20,15 +20,17 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
                                            num_cons = 0)
 
     if f.grad === nothing
-        function grad(res, θ, args...)
-            res .= zero(eltype(res))
-            Enzyme.autodiff(Enzyme.Reverse,
-                            Const(firstapply),
-                            Active,
-                            Const(f.f),
-                            Enzyme.Duplicated(θ, res),
-                            Const(p),
-                            args...)
+        grad = let
+            function (res, θ, args...)
+                res .= zero(eltype(res))
+                Enzyme.autodiff(Enzyme.Reverse,
+                                Const(firstapply),
+                                Active,
+                                Const(f.f),
+                                Enzyme.Duplicated(θ, res),
+                                Const(p),
+                                args...)
+            end
         end
     else
         grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
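
Two things change here, and both target Julia's compiler rather than Enzyme's API: annotating the callable as f::F ... where F forces firstapply to specialize on the concrete objective type, and grad is now an anonymous function built inside a let block, which avoids mixing a named inner function grad(...) definition with the plain grad = ... assignment in the else branch. A minimal sketch of the specialization idiom (helper names are illustrative, not from this commit):

    # When a callable is only passed along, Julia's heuristics may skip specializing on it,
    # leaving a dynamic call inside:
    call_inner(f, x) = f(x)
    passthrough(f, x) = call_inner(f, x)

    # The f::F ... where F annotation requests one specialization per concrete function type:
    passthrough_spec(f::F, x) where {F} = call_inner(f, x)

    square(y) = y * y
    passthrough_spec(square, 3.0)   # compiled specifically for typeof(square)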

ext/OptimizationReverseDiffExt.jl (130 additions, 27 deletions)
@@ -7,32 +7,60 @@ import Optimization.ADTypes: AutoReverseDiff
 isdefined(Base, :get_extension) ? (using ReverseDiff, ReverseDiff.ForwardDiff) :
     (using ..ReverseDiff, ..ReverseDiff.ForwardDiff)
 
+struct OptimizationReverseDiffTag end
+
 function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
                                            p = SciMLBase.NullParameters(),
                                            num_cons = 0)
     _f = (θ, args...) -> first(f.f(θ, p, args...))
 
     if f.grad === nothing
-        cfg = ReverseDiff.GradientConfig(x)
-        grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
+        if adtype.compile
+            _tape = ReverseDiff.GradientTape(_f, x)
+            tape = ReverseDiff.compile(_tape)
+            grad = function (res, θ, args...)
+                ReverseDiff.gradient!(res, tape, θ)
+            end
+        else
+            cfg = ReverseDiff.GradientConfig(x)
+            grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
+        end
     else
         grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
     end
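
With adtype.compile set, the objective gradient above is no longer re-traced on every call: the operations of _f are recorded once at the initial point into a ReverseDiff.GradientTape, the tape is compiled, and each later gradient evaluation just replays it. A minimal standalone sketch of that pattern (the objective and points are illustrative, not from this commit):

    using ReverseDiff

    rosenbrock(θ) = (1.0 - θ[1])^2 + 100.0 * (θ[2] - θ[1]^2)^2
    x0 = [0.0, 0.0]

    tape  = ReverseDiff.GradientTape(rosenbrock, x0)   # record the operations once
    ctape = ReverseDiff.compile(tape)                  # compile the recorded tape

    g = similar(x0)
    ReverseDiff.gradient!(g, ctape, [1.5, 2.0])        # replay at a new point, no re-tracing

The usual ReverseDiff caveat applies: a compiled tape freezes whatever control flow was taken during recording, so objectives with value-dependent branches should not use this path.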
 
     if f.hess === nothing
-        hess = function (res, θ, args...)
-            ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
+        if adtype.compile
+            T = ForwardDiff.Tag(OptimizationReverseDiffTag(),eltype(x))
+            xdual = ForwardDiff.Dual{typeof(T),eltype(x),length(x)}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), length(x))...,))))
+            h_tape = ReverseDiff.GradientTape(_f, xdual)
+            htape = ReverseDiff.compile(h_tape)
+            function g(θ)
+                res1 = zeros(eltype(θ), length(θ))
+                ReverseDiff.gradient!(res1, htape, θ)
+            end
+            jaccfg = ForwardDiff.JacobianConfig(g, x, ForwardDiff.Chunk(x), T)
+            hess = function (res, θ, args...)
+                ForwardDiff.jacobian!(res, g, θ, jaccfg, Val{false}())
+            end
+        else
+            hess = function (res, θ, args...)
+                ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
+            end
         end
     else
         hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
     end
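
The compiled Hessian path is a forward-over-reverse composition: a gradient tape is recorded at Dual-valued inputs, tagged with OptimizationReverseDiffTag so ForwardDiff's tag checks stay consistent, and the Hessian is then obtained as the ForwardDiff Jacobian of that taped gradient, since the Hessian of _f is the Jacobian of its gradient. A condensed sketch of the same composition outside instantiate_function (the function, tag type, and point are illustrative):

    using ReverseDiff, ReverseDiff.ForwardDiff

    f(θ) = sum(abs2, θ) + θ[1] * θ[2]
    x0 = [1.0, 2.0]

    struct MyHessTag end                          # stand-in for OptimizationReverseDiffTag
    T = ForwardDiff.Tag(MyHessTag(), eltype(x0))

    # Record and compile the gradient tape at Dual inputs so it can be replayed with Duals.
    xdual = ForwardDiff.Dual{typeof(T), eltype(x0), length(x0)}.(x0,
        Ref(ForwardDiff.Partials((ones(eltype(x0), length(x0))...,))))
    htape = ReverseDiff.compile(ReverseDiff.GradientTape(f, xdual))

    # Reverse-mode gradient from the compiled tape ...
    g(θ) = ReverseDiff.gradient!(zeros(eltype(θ), length(θ)), htape, θ)

    # ... differentiated in forward mode to obtain the Hessian.
    H = zeros(length(x0), length(x0))
    cfg = ForwardDiff.JacobianConfig(g, x0, ForwardDiff.Chunk(x0), T)
    ForwardDiff.jacobian!(H, g, x0, cfg, Val{false}())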
 
     if f.hv === nothing
         hv = function (H, θ, v, args...)
-            _θ = ForwardDiff.Dual.(θ, v)
-            res = similar(_θ)
-            grad(res, _θ, args...)
-            H .= getindex.(ForwardDiff.partials.(res), 1)
+            # _θ = ForwardDiff.Dual.(θ, v)
+            # res = similar(_θ)
+            # grad(res, _θ, args...)
+            # H .= getindex.(ForwardDiff.partials.(res), 1)
+            res = zeros(length(θ), length(θ))
+            hess(res, θ, args...)
+            H .= res * v
         end
     else
         hv = f.hv
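
The Hessian-vector product also changes: the old Dual-number trick (kept above as comments) pushed the direction v through grad as a perturbation, which no longer composes cleanly once grad may be a compiled-tape replay, so the new branch simply materializes the dense Hessian and multiplies it by v, at the cost of an n-by-n temporary per call. In isolation, what it computes (ForwardDiff.hessian stands in for the hess closure above; the function and vectors are illustrative):

    using ForwardDiff

    f(θ) = sum(abs2, θ) + θ[1] * θ[2]
    θ = [1.0, 2.0]
    v = [0.5, -1.0]

    H  = ForwardDiff.hessian(f, θ)   # plays the role of hess(res, θ, args...) filling res
    Hv = H * v                       # the product the new hv branch stores via H .= res * v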
@@ -46,19 +74,43 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
     end
 
     if cons !== nothing && f.cons_j === nothing
-        cjconfig = ReverseDiff.JacobianConfig(x)
-        cons_j = function (J, θ)
-            ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
+        if adtype.compile
+            _jac_tape = ReverseDiff.JacobianTape(cons_oop, x)
+            jac_tape = ReverseDiff.compile(_jac_tape)
+            cons_j = function (J, θ)
+                ReverseDiff.jacobian!(J, jac_tape, θ)
+            end
+        else
+            cjconfig = ReverseDiff.JacobianConfig(x)
+            cons_j = function (J, θ)
+                ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
+            end
         end
     else
         cons_j = (J, θ) -> f.cons_j(J, θ, p)
     end
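
Constraint Jacobians get the same treatment as the gradient: the out-of-place constraint function is recorded into a ReverseDiff.JacobianTape, compiled once, and replayed with ReverseDiff.jacobian!. A standalone sketch (the two-constraint function and points are illustrative, not from this commit):

    using ReverseDiff

    con(θ) = [θ[1]^2 + θ[2]^2, θ[1] * θ[2]]        # plays the role of cons_oop
    x0 = [1.0, 0.5]

    jtape = ReverseDiff.compile(ReverseDiff.JacobianTape(con, x0))

    J = zeros(2, length(x0))
    ReverseDiff.jacobian!(J, jtape, [2.0, 3.0])    # replay the compiled tape at a new point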
 
     if cons !== nothing && f.cons_h === nothing
         fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
-        cons_h = function (res, θ)
-            for i in 1:num_cons
-                ReverseDiff.hessian!(res[i], fncs[i], θ)
+        if adtype.compile
+            consh_tapes = ReverseDiff.GradientTape.(fncs, Ref(xdual))
+            conshtapes = ReverseDiff.compile.(consh_tapes)
+            function grad_cons(θ, htape)
+                res1 = zeros(eltype(θ), length(θ))
+                ReverseDiff.gradient!(res1, htape, θ)
+            end
+            gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons]
+            jaccfgs = [ForwardDiff.JacobianConfig(gs[i], x, ForwardDiff.Chunk(x), T) for i in 1:num_cons]
+            cons_h = function (res, θ)
+                for i in 1:num_cons
+                    ForwardDiff.jacobian!(res[i], gs[i], θ, jaccfgs[i], Val{false}())
+                end
+            end
+        else
+            cons_h = function (res, θ)
+                for i in 1:num_cons
+                    ReverseDiff.hessian!(res[i], fncs[i], θ)
+                end
             end
         end
     else
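
The constraint Hessians reuse the forward-over-reverse machinery set up for the objective: each scalar constraint fncs[i] gets its own gradient tape recorded at the Dual-valued xdual and compiled, and ForwardDiff.jacobian! of that taped gradient, with its per-constraint JacobianConfig, fills res[i]. The pattern is the one sketched for the objective Hessian above, broadcast over the per-constraint closures. Note that T and xdual are created in the f.hess === nothing branch, so the compiled constraint-Hessian path appears to assume the objective Hessian is also AD-generated rather than user-supplied.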
@@ -83,25 +135,52 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
     _f = (θ, args...) -> first(f.f(θ, cache.p, args...))
 
     if f.grad === nothing
-        grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ)
+        if adtype.compile
+            _tape = ReverseDiff.GradientTape(_f, cache.u0)
+            tape = ReverseDiff.compile(_tape)
+            grad = function (res, θ, args...)
+                ReverseDiff.gradient!(res, tape, θ)
+            end
+        else
+            cfg = ReverseDiff.GradientConfig(cache.u0)
+            grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
+        end
     else
         grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
     end
 
     if f.hess === nothing
-        hess = function (res, θ, args...)
-            ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
+        if adtype.compile
+            T = ForwardDiff.Tag(OptimizationReverseDiffTag(),eltype(cache.u0))
+            xdual = ForwardDiff.Dual{typeof(T),eltype(cache.u0),length(cache.u0)}.(cache.u0, Ref(ForwardDiff.Partials((ones(eltype(cache.u0), length(cache.u0))...,))))
+            h_tape = ReverseDiff.GradientTape(_f, xdual)
+            htape = ReverseDiff.compile(h_tape)
+            function g(θ)
+                res1 = zeros(eltype(θ), length(θ))
+                ReverseDiff.gradient!(res1, htape, θ)
+            end
+            jaccfg = ForwardDiff.JacobianConfig(g, cache.u0, ForwardDiff.Chunk(cache.u0), T)
+            hess = function (res, θ, args...)
+                ForwardDiff.jacobian!(res, g, θ, jaccfg, Val{false}())
+            end
+        else
+            hess = function (res, θ, args...)
+                ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
+            end
         end
     else
         hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
     end
 
     if f.hv === nothing
         hv = function (H, θ, v, args...)
-            _θ = ForwardDiff.Dual.(θ, v)
-            res = similar(_θ)
-            grad(res, _θ, args...)
-            H .= getindex.(ForwardDiff.partials.(res), 1)
+            # _θ = ForwardDiff.Dual.(θ, v)
+            # res = similar(_θ)
+            # grad(res, θ, args...)
+            # H .= getindex.(ForwardDiff.partials.(res), 1)
+            res = zeros(length(θ), length(θ))
+            hess(res, θ, args...)
+            H .= res * v
         end
     else
         hv = f.hv
@@ -115,19 +194,43 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
     end
 
     if cons !== nothing && f.cons_j === nothing
-        cjconfig = ReverseDiff.JacobianConfig(cache.u0)
-        cons_j = function (J, θ)
-            ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
+        if adtype.compile
+            _jac_tape = ReverseDiff.JacobianTape(cons_oop, cache.u0)
+            jac_tape = ReverseDiff.compile(_jac_tape)
+            cons_j = function (J, θ)
+                ReverseDiff.jacobian!(J, jac_tape, θ)
+            end
+        else
+            cjconfig = ReverseDiff.JacobianConfig(cache.u0)
+            cons_j = function (J, θ)
+                ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
+            end
         end
     else
         cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
     end
 
     if cons !== nothing && f.cons_h === nothing
         fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
-        cons_h = function (res, θ)
-            for i in 1:num_cons
-                ReverseDiff.hessian!(res[i], fncs[i], θ)
+        if adtype.compile
+            consh_tapes = ReverseDiff.GradientTape.(fncs, Ref(xdual))
+            conshtapes = ReverseDiff.compile.(consh_tapes)
+            function grad_cons(θ, htape)
+                res1 = zeros(eltype(θ), length(θ))
+                ReverseDiff.gradient!(res1, htape, θ)
+            end
+            gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons]
+            jaccfgs = [ForwardDiff.JacobianConfig(gs[i], cache.u0, ForwardDiff.Chunk(cache.u0), T) for i in 1:num_cons]
+            cons_h = function (res, θ)
+                for i in 1:num_cons
+                    ForwardDiff.jacobian!(res[i], gs[i], θ, jaccfgs[i], Val{false}())
+                end
+            end
+        else
+            cons_h = function (res, θ)
+                for i in 1:num_cons
+                    ReverseDiff.hessian!(res[i], fncs[i], θ)
+                end
            end
         end
     else
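
From the user side, these code paths are selected through the compile field of ADTypes.AutoReverseDiff, which defaults to false. A minimal end-to-end sketch of opting in from Optimization.jl (the solver choice and problem are illustrative, not part of this commit):

    using Optimization, OptimizationOptimJL, ReverseDiff

    rosenbrock(θ, p) = (p[1] - θ[1])^2 + p[2] * (θ[2] - θ[1]^2)^2

    optf = OptimizationFunction(rosenbrock, Optimization.AutoReverseDiff(compile = true))
    prob = OptimizationProblem(optf, [0.0, 0.0], [1.0, 100.0])

    # Newton exercises both the compiled-tape gradient and the forward-over-reverse Hessian;
    # safe here because rosenbrock takes no value-dependent branches.
    sol = solve(prob, Newton())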
