Skip to content

Commit 042d75f

Browse files
authored
Operators that combine first and second order (#300)
* Update fallback for mutating second-order operators * Enable operators that combine first and second order * Docstring * Scenarios * Sparse scenarios * Scenario intact * Fix benchmarks * Fix ref * Logging * Useless logging * Typo * Typo * FastDiff fix * Format
1 parent d5cb08b commit 042d75f

File tree

13 files changed

+378
-128
lines changed

13 files changed

+378
-128
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.5.3"
4+
version = "0.5.4"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/docs/src/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ SecondOrder
7474
prepare_second_derivative
7575
second_derivative
7676
second_derivative!
77+
value_derivative_and_second_derivative
78+
value_derivative_and_second_derivative!
7779
```
7880

7981
### Hessian-vector product

DifferentiationInterface/docs/src/operators.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,16 @@ These operators are computed using the input `x` and a "seed" `v`, which lives e
4545

4646
Several variants of each operator are defined.
4747

48-
| out-of-place | in-place | out-of-place + primal | in-place + primal |
49-
| :-------------------------- | :--------------------------- | :------------------------------ | :------------------------------- |
50-
| [`derivative`](@ref) | [`derivative!`](@ref) | [`value_and_derivative`](@ref) | [`value_and_derivative!`](@ref) |
51-
| [`second_derivative`](@ref) | [`second_derivative!`](@ref) | NA | NA |
52-
| [`gradient`](@ref) | [`gradient!`](@ref) | [`value_and_gradient`](@ref) | [`value_and_gradient!`](@ref) |
53-
| [`hessian`](@ref) | [`hessian!`](@ref) | NA | NA |
54-
| [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) |
55-
| [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) |
56-
| [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) |
57-
| [`hvp`](@ref) | [`hvp!`](@ref) | NA | NA |
48+
| out-of-place | in-place | out-of-place + primal | in-place + primal |
49+
| :-------------------------- | :--------------------------- | :----------------------------------------------- | :----------------------------------------------- |
50+
| [`derivative`](@ref) | [`derivative!`](@ref) | [`value_and_derivative`](@ref) | [`value_and_derivative!`](@ref) |
51+
| [`second_derivative`](@ref) | [`second_derivative!`](@ref) | [`value_derivative_and_second_derivative`](@ref) | [`value_derivative_and_second_derivative!`](@ref) |
52+
| [`gradient`](@ref) | [`gradient!`](@ref) | [`value_and_gradient`](@ref) | [`value_and_gradient!`](@ref) |
53+
| [`hessian`](@ref) | [`hessian!`](@ref) | NA | NA |
54+
| [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) |
55+
| [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) |
56+
| [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) |
57+
| [`hvp`](@ref) | [`hvp!`](@ref) | NA | NA |
5858

5959
## Mutation and signatures
6060

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,11 +307,13 @@ end
307307

308308
## Second derivative
309309

310-
struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,E1,E2} <:
310+
struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,E1,E1!,E2,E2!} <:
311311
SecondDerivativeExtras
312312
y_prototype::Y
313-
der2_exe::E1
314-
der2_exe!::E2
313+
der_exe::E1
314+
der_exe!::E1!
315+
der2_exe::E2
316+
der2_exe!::E2!
315317
end
316318

317319
function DI.prepare_second_derivative(f, ::AutoFastDifferentiation, x)
@@ -321,11 +323,18 @@ function DI.prepare_second_derivative(f, ::AutoFastDifferentiation, x)
321323

322324
x_vec_var = monovec(x_var)
323325
y_vec_var = y_var isa Number ? monovec(y_var) : vec(y_var)
326+
327+
der_vec_var = derivative(y_vec_var, x_var)
324328
der2_vec_var = derivative(y_vec_var, x_var, x_var)
329+
330+
der_exe = make_function(der_vec_var, x_vec_var; in_place=false)
331+
der_exe! = make_function(der_vec_var, x_vec_var; in_place=true)
332+
325333
der2_exe = make_function(der2_vec_var, x_vec_var; in_place=false)
326334
der2_exe! = make_function(der2_vec_var, x_vec_var; in_place=true)
335+
327336
return FastDifferentiationAllocatingSecondDerivativeExtras(
328-
y_prototype, der2_exe, der2_exe!
337+
y_prototype, der_exe, der_exe!, der2_exe, der2_exe!
329338
)
330339
end
331340

@@ -353,6 +362,38 @@ function DI.second_derivative!(
353362
return der2
354363
end
355364

365+
function DI.value_derivative_and_second_derivative(
366+
f,
367+
::AutoFastDifferentiation,
368+
x,
369+
extras::FastDifferentiationAllocatingSecondDerivativeExtras,
370+
)
371+
y = f(x)
372+
if extras.y_prototype isa Number
373+
der = only(extras.der_exe(monovec(x)))
374+
der2 = only(extras.der2_exe(monovec(x)))
375+
return y, der, der2
376+
else
377+
der = reshape(extras.der_exe(monovec(x)), size(extras.y_prototype))
378+
der2 = reshape(extras.der2_exe(monovec(x)), size(extras.y_prototype))
379+
return y, der, der2
380+
end
381+
end
382+
383+
function DI.value_derivative_and_second_derivative!(
384+
f,
385+
der,
386+
der2,
387+
backend::AutoFastDifferentiation,
388+
x,
389+
extras::FastDifferentiationAllocatingSecondDerivativeExtras,
390+
)
391+
y = f(x)
392+
extras.der_exe!(vec(der), monovec(x))
393+
extras.der2_exe!(vec(der2), monovec(x))
394+
return y, der, der2
395+
end
396+
356397
## HVP
357398

358399
struct FastDifferentiationHVPExtras{E1,E2} <: HVPExtras

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ export gradient!, gradient
9696
export jacobian!, jacobian
9797

9898
export second_derivative!, second_derivative
99+
export value_derivative_and_second_derivative, value_derivative_and_second_derivative!
99100
export hvp!, hvp
100101
export hessian!, hessian
101102

DifferentiationInterface/src/second_order/second_derivative.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ Compute the second derivative of the function `f` at point `x`, overwriting `der
2424
"""
2525
function second_derivative! end
2626

27+
"""
28+
value_derivative_and_second_derivative(f, backend, x, [extras]) -> (y, der, der2)
29+
30+
Compute the value, first derivative and second derivative of the function `f` at point `x`.
31+
"""
32+
function value_derivative_and_second_derivative end
33+
34+
"""
35+
value_derivative_and_second_derivative!(f, der, der2, backend, x, [extras]) -> (y, der, der2)
36+
37+
Compute the value, first derivative and second derivative of the function `f` at point `x`, overwriting `der` and `der2`.
38+
"""
39+
function value_derivative_and_second_derivative! end
40+
2741
## Preparation
2842

2943
"""
@@ -74,6 +88,31 @@ function second_derivative(
7488
return derivative(inner_derivative_closure, outer(backend), x, outer_derivative_extras)
7589
end
7690

91+
function value_derivative_and_second_derivative(
92+
f::F,
93+
backend::AbstractADType,
94+
x,
95+
extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x),
96+
) where {F}
97+
return value_derivative_and_second_derivative(
98+
f, SecondOrder(backend, backend), x, extras
99+
)
100+
end
101+
102+
function value_derivative_and_second_derivative(
103+
f::F,
104+
backend::SecondOrder,
105+
x,
106+
extras::ClosureSecondDerivativeExtras=prepare_second_derivative(f, backend, x),
107+
) where {F}
108+
@compat (; inner_derivative_closure, outer_derivative_extras) = extras
109+
y = f(x)
110+
der, der2 = value_and_derivative(
111+
inner_derivative_closure, outer(backend), x, outer_derivative_extras
112+
)
113+
return y, der, der2
114+
end
115+
77116
function second_derivative!(
78117
f::F,
79118
der2,
@@ -96,3 +135,32 @@ function second_derivative!(
96135
inner_derivative_closure, der2, outer(backend), x, outer_derivative_extras
97136
)
98137
end
138+
139+
function value_derivative_and_second_derivative!(
140+
f::F,
141+
der,
142+
der2,
143+
backend::AbstractADType,
144+
x,
145+
extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x),
146+
) where {F}
147+
return value_derivative_and_second_derivative!(
148+
f, der, der2, SecondOrder(backend, backend), x, extras
149+
)
150+
end
151+
152+
function value_derivative_and_second_derivative!(
153+
f::F,
154+
der,
155+
der2,
156+
backend::SecondOrder,
157+
x,
158+
extras::SecondDerivativeExtras=prepare_second_derivative(f, backend, x),
159+
) where {F}
160+
@compat (; inner_derivative_closure, outer_derivative_extras) = extras
161+
y = f(x)
162+
new_der, _ = value_and_derivative!(
163+
inner_derivative_closure, der2, outer(backend), x, outer_derivative_extras
164+
)
165+
return y, copyto!(der, new_der), der2
166+
end

DifferentiationInterfaceTest/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterfaceTest"
22
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.4.2"
4+
version = "0.4.3"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterfaceTest/src/scenarios/default.jl

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ function num_to_num_scenarios_onearg(x::Number)
2525
PullbackScenario(num_to_num; x=x, ref=num_to_num_pullback, place=:outofplace),
2626
DerivativeScenario(num_to_num; x=x, ref=num_to_num_derivative, place=:outofplace),
2727
SecondDerivativeScenario(
28-
num_to_num; x=x, ref=num_to_num_second_derivative, place=:outofplace
28+
num_to_num;
29+
x=x,
30+
ref=num_to_num_second_derivative,
31+
first_order_ref=num_to_num_derivative,
32+
place=:outofplace,
2933
),
3034
]
3135
end
@@ -79,7 +83,11 @@ function num_to_arr_scenarios_onearg(x::Number, a::AbstractArray)
7983
_num_to_arr(a); x=x, ref=_num_to_arr_derivative(a), place=place
8084
),
8185
SecondDerivativeScenario(
82-
_num_to_arr(a); x=x, ref=_num_to_arr_second_derivative(a), place=place
86+
_num_to_arr(a);
87+
x=x,
88+
ref=_num_to_arr_second_derivative(a),
89+
first_order_ref=_num_to_arr_derivative(a),
90+
place=place,
8391
),
8492
],
8593
)
@@ -205,8 +213,20 @@ function arr_to_num_scenarios_onearg(x::AbstractArray; linalg=true)
205213
[
206214
PullbackScenario(arr_to_num; x=x, ref=arr_to_num_pullback, place=place),
207215
GradientScenario(arr_to_num; x=x, ref=arr_to_num_gradient, place=place),
208-
HVPScenario(arr_to_num; x=x, ref=arr_to_num_hvp, place=place),
209-
HessianScenario(arr_to_num; x=x, ref=arr_to_num_hessian, place=place),
216+
HVPScenario(
217+
arr_to_num;
218+
x=x,
219+
ref=arr_to_num_hvp,
220+
first_order_ref=arr_to_num_gradient,
221+
place=place,
222+
),
223+
HessianScenario(
224+
arr_to_num;
225+
x=x,
226+
ref=arr_to_num_hessian,
227+
first_order_ref=arr_to_num_gradient,
228+
place=place,
229+
),
210230
],
211231
)
212232
end

0 commit comments

Comments
 (0)