Skip to content

Commit c6aaabe

Browse files
authored
Better hessian check, Const(f) in Enzyme (#284)
* Small changes * Hessian check * Correct hess_th
1 parent 2583e79 commit c6aaabe

File tree

6 files changed

+29
-13
lines changed

6 files changed

+29
-13
lines changed

DifferentiationInterface/docs/src/tutorial1.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,16 @@ using DifferentiationInterface
99
## Computing a gradient
1010

1111
A common use case of automatic differentiation (AD) is optimizing real-valued functions with first- or second-order methods.
12-
Let's define a simple objective and a random input vector
12+
Let's define a simple objective (the squared norm) and a random input vector
1313

1414
```@example tuto1
15-
f(x) = sum(abs2, x)
15+
function f(x::AbstractVector{T}) where {T}
16+
y = zero(T)
17+
for i in eachindex(x)
18+
y += abs2(x[i])
19+
end
20+
return y
21+
end
1622
1723
x = collect(1.0:5.0)
1824
```

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ function DI.value_and_pushforward(
66
f, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras
77
)
88
dx_sametype = convert(typeof(x), dx)
9-
y, new_dy = autodiff(forward_mode(backend), f, Duplicated, Duplicated(x, dx_sametype))
9+
y, new_dy = autodiff(
10+
forward_mode(backend), Const(f), Duplicated, Duplicated(x, dx_sametype)
11+
)
1012
return y, new_dy
1113
end
1214

@@ -15,7 +17,9 @@ function DI.pushforward(
1517
)
1618
dx_sametype = convert(typeof(x), dx)
1719
new_dy = only(
18-
autodiff(forward_mode(backend), f, DuplicatedNoNeed, Duplicated(x, dx_sametype))
20+
autodiff(
21+
forward_mode(backend), Const(f), DuplicatedNoNeed, Duplicated(x, dx_sametype)
22+
),
1923
)
2024
return new_dy
2125
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function DI.value_and_pushforward(
99
dy_sametype = zero(y)
1010
autodiff(
1111
forward_mode(backend),
12-
f!,
12+
Const(f!),
1313
Const,
1414
Duplicated(y, dy_sametype),
1515
Duplicated(x, dx_sametype),

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ DI.prepare_pullback(f, ::AutoReverseOrNothingEnzyme, x, dy) = NoPullbackExtras()
77
function DI.value_and_pullback(
88
f, ::AutoReverseOrNothingEnzyme, x::Number, dy::Number, ::NoPullbackExtras
99
)
10-
der, y = autodiff(ReverseWithPrimal, f, Active, Active(x))
10+
der, y = autodiff(ReverseWithPrimal, Const(f), Active, Active(x))
1111
new_dx = dy * only(der)
1212
return y, new_dx
1313
end
@@ -43,7 +43,7 @@ function DI.value_and_pullback!(
4343
f, dx, ::AutoReverseOrNothingEnzyme, x::AbstractArray, dy::Number, ::NoPullbackExtras
4444
)
4545
dx_sametype = zero_sametype!(dx, x)
46-
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx_sametype))
46+
_, y = autodiff(ReverseWithPrimal, Const(f), Active, Duplicated(x, dx_sametype))
4747
dx_sametype .*= dy
4848
return y, copyto!(dx, dx_sametype)
4949
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ function DI.value_and_pullback(
77
)
88
dy_sametype = convert(typeof(y), copy(dy))
99
_, new_dx = only(
10-
autodiff(reverse_mode(backend), f!, Const, Duplicated(y, dy_sametype), Active(x))
10+
autodiff(
11+
reverse_mode(backend), Const(f!), Const, Duplicated(y, dy_sametype), Active(x)
12+
),
1113
)
1214
return y, new_dx
1315
end
@@ -19,7 +21,7 @@ function DI.value_and_pullback(
1921
dy_sametype = convert(typeof(y), copy(dy))
2022
autodiff(
2123
reverse_mode(backend),
22-
f!,
24+
Const(f!),
2325
Const,
2426
Duplicated(y, dy_sametype),
2527
Duplicated(x, dx_sametype),

DifferentiationInterface/src/utils/check.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Check whether `backend` supports differentiation of two-argument functions.
1616
"""
1717
check_twoarg(backend::AbstractADType) = Bool(twoarg_support(backend))
1818

19-
sqnorm(x::AbstractArray) = sum(abs2, x)
19+
hess_checker(x::AbstractArray) = abs2(x[1]) * abs2(x[2])
2020

2121
"""
2222
check_hessian(backend)
@@ -28,9 +28,13 @@ Check whether `backend` supports second order differentiation by trying to compu
2828
"""
2929
function check_hessian(backend::AbstractADType; verbose=true)
3030
try
31-
x = [1.0, 3.0]
32-
hess = hessian(sqnorm, backend, x)
33-
return isapprox(hess, [2.0 0.0; 0.0 2.0]; rtol=1e-3)
31+
x = [2.0, 3.0]
32+
hess = hessian(hess_checker, backend, x)
33+
hess_th = [
34+
2*abs2(x[2]) 4*x[1]*x[2]
35+
4*x[1]*x[2] 2*abs2(x[1])
36+
]
37+
return isapprox(hess, hess_th; rtol=1e-3)
3438
catch exception
3539
if verbose
3640
@warn "Backend $backend does not support hessian" exception

0 commit comments

Comments
 (0)