Skip to content

Commit 3417ff5

Browse files
authored
perf: avoid double function call in ReverseDiff value_and_gradient (#729)
* perf: avoid double function call in ReverseDiff `value_and_gradient` * Fixes * Fix hessian * Separate completely * Replace DR.gradient(result) with grad
1 parent 98e6e5f commit 3417ff5

File tree

3 files changed

+55
-79
lines changed

3 files changed

+55
-79
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.40"
4+
version = "0.6.41"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl

Lines changed: 51 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -87,23 +87,26 @@ end
8787
function DI.value_and_gradient!(
8888
f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff{compile}, x
8989
) where {compile}
90-
y = f(x) # TODO: ReverseDiff#251
91-
result = DiffResult(y, (grad,))
90+
result = MutableDiffResult(zero(eltype(x)), (grad,)) # ReverseDiff#251
9291
if compile
9392
result = gradient!(result, prep.tape, x)
9493
else
9594
result = gradient!(result, f, x, prep.config)
9695
end
97-
y = DR.value(result)
98-
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
99-
return y, grad
96+
return DR.value(result), grad # ReverseDiff#269
10097
end
10198

10299
function DI.value_and_gradient(
103-
f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff, x
104-
)
105-
grad = similar(x)
106-
return DI.value_and_gradient!(f, grad, prep, backend, x)
100+
f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x
101+
) where {compile}
102+
# GradientResult tries to mutate an SArray
103+
result = MutableDiffResult(zero(eltype(x)), (similar(x),))
104+
if compile
105+
result = gradient!(result, prep.tape, x)
106+
else
107+
result = gradient!(result, f, x, prep.config)
108+
end
109+
return DR.value(result), DR.gradient(result)
107110
end
108111

109112
function DI.gradient!(
@@ -144,23 +147,19 @@ function DI.value_and_gradient!(
144147
contexts::Vararg{DI.Context,C},
145148
) where {C}
146149
fc = DI.with_contexts(f, contexts...)
147-
y = fc(x) # TODO: ReverseDiff#251
148-
result = DiffResult(y, (grad,))
150+
result = MutableDiffResult(zero(eltype(x)), (grad,)) # ReverseDiff#251
149151
result = gradient!(result, fc, x, prep.config)
150-
y = DR.value(result)
151-
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
152-
return y, grad
152+
return DR.value(result), grad # ReverseDiff#269
153153
end
154154

155155
function DI.value_and_gradient(
156-
f,
157-
prep::ReverseDiffGradientPrep,
158-
backend::AutoReverseDiff,
159-
x,
160-
contexts::Vararg{DI.Context,C},
156+
f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}
161157
) where {C}
162-
grad = similar(x)
163-
return DI.value_and_gradient!(f, grad, prep, backend, x, contexts...)
158+
fc = DI.with_contexts(f, contexts...)
159+
# GradientResult tries to mutate an SArray
160+
result = MutableDiffResult(zero(eltype(x)), (similar(x),))
161+
result = gradient!(result, fc, x, prep.config)
162+
return DR.value(result), DR.gradient(result)
164163
end
165164

166165
function DI.gradient!(
@@ -310,31 +309,23 @@ end
310309

311310
### Without contexts
312311

313-
@kwdef struct ReverseDiffHessianPrep{GC,HC,GT,HT} <: DI.HessianPrep
314-
gradient_config::GC
312+
@kwdef struct ReverseDiffHessianPrep{G<:ReverseDiffGradientPrep,HC,HT} <: DI.HessianPrep
313+
gradient_prep::G
315314
hessian_config::HC
316-
gradient_tape::GT
317315
hessian_tape::HT
318316
end
319317

320-
function DI.prepare_hessian(f, ::AutoReverseDiff{compile}, x) where {compile}
318+
function DI.prepare_hessian(f, backend::AutoReverseDiff{compile}, x) where {compile}
319+
gradient_prep = DI.prepare_gradient(f, backend, x)
321320
if compile
322-
gradient_tape = ReverseDiff.compile(GradientTape(f, x))
323321
hessian_tape = ReverseDiff.compile(HessianTape(f, x))
324322
return ReverseDiffHessianPrep(;
325-
gradient_config=nothing,
326-
hessian_config=nothing,
327-
gradient_tape=gradient_tape,
328-
hessian_tape=hessian_tape,
323+
gradient_prep, hessian_config=nothing, hessian_tape=hessian_tape
329324
)
330325
else
331-
gradient_config = GradientConfig(x)
332326
hessian_config = HessianConfig(x)
333327
return ReverseDiffHessianPrep(;
334-
gradient_config=gradient_config,
335-
hessian_config=hessian_config,
336-
gradient_tape=nothing,
337-
hessian_tape=nothing,
328+
gradient_prep, hessian_config=hessian_config, hessian_tape=nothing
338329
)
339330
end
340331
end
@@ -360,47 +351,32 @@ function DI.hessian(
360351
end
361352

362353
function DI.value_gradient_and_hessian!(
363-
f, grad, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff{compile}, x
354+
f, grad, hess, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x
364355
) where {compile}
365-
y = f(x) # TODO: ReverseDiff#251
366-
result = DiffResult(y, (grad, hess))
367-
if compile
368-
result = hessian!(result, prep.hessian_tape, x)
369-
grad = gradient!(grad, prep.gradient_tape, x) # TODO: ReverseDiff#251
370-
else
371-
result = hessian!(result, f, x) # TODO: add prep.hessian_config
372-
grad = gradient!(grad, f, x, prep.gradient_config) # TODO: ReverseDiff#251
373-
end
374-
# grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
375-
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
356+
y = f(x)
357+
DI.gradient!(f, grad, prep.gradient_prep, backend, x)
358+
DI.hessian!(f, hess, prep, backend, x)
376359
return y, grad, hess
377360
end
378361

379362
function DI.value_gradient_and_hessian(
380-
f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff{compile}, x
363+
f, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x
381364
) where {compile}
382-
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
383-
result = DiffResult(y, (similar(x), similar(x, length(x), length(x))))
384-
if compile
385-
result = hessian!(result, prep.hessian_tape, x)
386-
else
387-
result = hessian!(result, f, x) # todo: add prep.hessian_config
388-
end
389-
return (y, DR.gradient(result), DR.hessian(result))
365+
y = f(x)
366+
grad = DI.gradient(f, prep.gradient_prep, backend, x)
367+
hess = DI.hessian(f, prep, backend, x)
368+
return y, grad, hess
390369
end
391370

392371
### With contexts
393372

394373
function DI.prepare_hessian(
395-
f, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}
374+
f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}
396375
) where {C}
397-
gradient_config = GradientConfig(x)
376+
gradient_prep = DI.prepare_gradient(f, backend, x, contexts...)
398377
hessian_config = HessianConfig(x)
399378
return ReverseDiffHessianPrep(;
400-
gradient_config=gradient_config,
401-
hessian_config=hessian_config,
402-
gradient_tape=nothing,
403-
hessian_tape=nothing,
379+
gradient_prep, hessian_config=hessian_config, hessian_tape=nothing
404380
)
405381
end
406382

@@ -428,27 +404,25 @@ function DI.value_gradient_and_hessian!(
428404
grad,
429405
hess,
430406
prep::ReverseDiffHessianPrep,
431-
::AutoReverseDiff,
407+
backend::AutoReverseDiff,
432408
x,
433409
contexts::Vararg{DI.Context,C},
434410
) where {C}
435-
fc = DI.with_contexts(f, contexts...)
436-
y = fc(x) # TODO: ReverseDiff#251
437-
result = DiffResult(y, (grad, hess))
438-
result = hessian!(result, fc, x) # TODO: add prep.hessian_config
439-
y = DR.value(result)
440-
# grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
441-
grad = gradient!(grad, fc, x, prep.gradient_config) # TODO: ReverseDiff#251
442-
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
411+
y = f(x, map(DI.unwrap, contexts)...)
412+
DI.gradient!(f, grad, prep.gradient_prep, backend, x, contexts...)
413+
DI.hessian!(f, hess, prep, backend, x, contexts...)
443414
return y, grad, hess
444415
end
445416

446417
function DI.value_gradient_and_hessian(
447-
f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}
418+
f,
419+
prep::ReverseDiffHessianPrep,
420+
backend::AutoReverseDiff,
421+
x,
422+
contexts::Vararg{DI.Context,C},
448423
) where {C}
449-
fc = DI.with_contexts(f, contexts...)
450-
y = fc(x) # TODO: ReverseDiff#251
451-
result = HessianResult(x)
452-
result = hessian!(result, fc, x) # TODO: add prep.hessian_config
453-
return (DR.value(result), DR.gradient(result), DR.hessian(result))
424+
y = f(x, map(DI.unwrap, contexts)...)
425+
grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...)
426+
hess = DI.hessian(f, prep, backend, x, contexts...)
427+
return y, grad, hess
454428
end

DifferentiationInterface/test/Back/ReverseDiff/test.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ test_differentiation(
2929
logging=LOGGING,
3030
);
3131

32-
test_differentiation(backends, static_scenarios(); logging=LOGGING);
32+
test_differentiation(
33+
backends, static_scenarios(; include_constantified=true); logging=LOGGING
34+
);
3335

3436
## Sparse
3537

0 commit comments

Comments
 (0)