Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@ using ForwardDiff
using Zygote: hessian_dual, hessian_reverse

@testset "hessian: $hess" for hess in [hessian_dual, hessian_reverse]
function f(x, bias)
hessian = hess(x->sum(x.^3), x)
return hessian * x .+ bias
end

if hess == hessian_dual
@test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0]
@test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] # original docstring version
@test gradient(b->sum(f(rand(3),b)),rand(3))[1] ≈ [1, 1, 1]
else
@test_broken hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] # can't differentiate ∇getindex
@test_broken hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0]
@test_broken gradient(b->sum(f(rand(3),b)),rand(3))[1] ≈ [1, 1, 1] # jacobian is not differentiable
end
@test hess(x -> sum(x.^3), [1 2; 3 4]) ≈ Diagonal([6, 18, 12, 24])
@test hess(sin, pi/2) ≈ -1
Expand Down Expand Up @@ -133,7 +139,7 @@ using ForwardDiff
g3(x) = sum(abs2,ForwardDiff.jacobian(f,x))
out,back = Zygote.pullback(g3,[2.0,3.2])
@test back(1.0)[1] == ForwardDiff.gradient(g3,[2.0,3.2])

# From https://github.com/FluxML/Zygote.jl/issues/1218
f1218(x::AbstractVector,y::AbstractVector) = sum(x)*sum(y)
gradf1218(x,y) = ForwardDiff.gradient(x->f1218(x,y), x)[1]
Expand Down