Skip to content

Commit 2acd430

Browse files
authored
Better preparation for second order with single backend (#287)
1 parent 39dda67 commit 2acd430

File tree

4 files changed

+13
-14
lines changed

4 files changed

+13
-14
lines changed

DifferentiationInterface/docs/src/operators.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,6 @@ Some backends natively support a set of second-order operators (typically only t
135135
In that case, it can be advantageous to use the backend on its own.
136136
If the operator is not supported natively, we will fall back on `SecondOrder(backend, backend)` (see below).
137137

138-
!!! warning
139-
Whenever the fallback on `SecondOrder(backend, backend)` occurs, the results of any preparation will be discarded.
140-
141138
### Combining backends
142139

143140
In general, you can use [`SecondOrder`](@ref) to combine different backends.

DifferentiationInterface/src/second_order/hessian.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ struct HVPHessianExtras{E<:HVPExtras} <: HessianExtras
4040
end
4141

4242
function prepare_hessian(f::F, backend::AbstractADType, x) where {F}
43+
return prepare_hessian(f, SecondOrder(backend, backend), x)
44+
end
45+
46+
function prepare_hessian(f::F, backend::SecondOrder, x) where {F}
4347
v = basis(backend, x, first(CartesianIndices(x)))
4448
hvp_extras = prepare_hvp(f, backend, x, v)
4549
return HVPHessianExtras(hvp_extras)
@@ -50,9 +54,7 @@ end
5054
function hessian(
5155
f::F, backend::AbstractADType, x, extras::HessianExtras=prepare_hessian(f, backend, x)
5256
) where {F}
53-
new_backend = SecondOrder(backend, backend)
54-
new_extras = prepare_hessian(f, new_backend, x)
55-
return hessian(f, new_backend, x, new_extras)
57+
return hessian(f, SecondOrder(backend, backend), x, extras)
5658
end
5759

5860
function hessian(

DifferentiationInterface/src/second_order/hvp.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ struct ReverseOverReverseHVPExtras{C,E} <: HVPExtras
7575
outer_pullback_extras::E
7676
end
7777

78-
prepare_hvp(f, ::AbstractADType, x, v) = NoHVPExtras()
78+
function prepare_hvp(f::F, backend::AbstractADType, x, v) where {F}
79+
return prepare_hvp(f, SecondOrder(backend, backend), x, v)
80+
end
7981

8082
function prepare_hvp(f::F, backend::SecondOrder, x, v) where {F}
8183
return prepare_hvp_aux(f, backend, x, v, hvp_mode(backend))
@@ -143,9 +145,7 @@ end
143145
function hvp(
144146
f::F, backend::AbstractADType, x, v, extras::HVPExtras=prepare_hvp(f, backend, x, v)
145147
) where {F}
146-
new_backend = SecondOrder(backend, backend)
147-
new_extras = prepare_hvp(f, new_backend, x, v)
148-
return hvp(f, new_backend, x, v, new_extras)
148+
return hvp(f, SecondOrder(backend, backend), x, v, extras)
149149
end
150150

151151
function hvp(

DifferentiationInterface/src/second_order/second_derivative.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ struct ClosureSecondDerivativeExtras{C,E} <: SecondDerivativeExtras
4040
outer_derivative_extras::E
4141
end
4242

43-
prepare_second_derivative(f::F, ::AbstractADType, x) where {F} = NoSecondDerivativeExtras()
43+
function prepare_second_derivative(f::F, backend::AbstractADType, x) where {F}
44+
return prepare_second_derivative(f, SecondOrder(backend, backend), x)
45+
end
4446

4547
function prepare_second_derivative(f::F, backend::SecondOrder, x) where {F}
4648
inner_backend = nested(inner(backend))
@@ -59,9 +61,7 @@ function second_derivative(
5961
x,
6062
extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x),
6163
) where {F}
62-
new_backend = SecondOrder(backend, backend)
63-
new_extras = prepare_second_derivative(f, new_backend, x)
64-
return second_derivative(f, new_backend, x, new_extras)
64+
return second_derivative(f, SecondOrder(backend, backend), x, extras)
6565
end
6666

6767
function second_derivative(

0 commit comments

Comments
 (0)