Skip to content

Commit dccd07e

Browse files
authored
fix: revert catch of nothing pullbacks with Zygote (#714)
1 parent 19803d1 commit dccd07e

File tree

3 files changed

+1
-34
lines changed

3 files changed

+1
-34
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.38"
4+
version = "0.6.39"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,6 @@ using Zygote:
1313
withgradient,
1414
withjacobian
1515

16-
struct ZygoteNothingError <: Exception
17-
f
18-
x
19-
contexts
20-
end
21-
22-
function Base.showerror(io::IO, e::ZygoteNothingError)
23-
(; f, x, contexts) = e
24-
sig = (typeof(x), map(typeof DI.unwrap, contexts)...)
25-
return print(
26-
io,
27-
"Zygote failed to differentiate function `$f` with argument types `$sig` (the pullback returned `nothing`).",
28-
)
29-
end
30-
31-
check_nothing(::Nothing, f, x, contexts) = throw(ZygoteNothingError(f, x, contexts))
32-
check_nothing(::Any, f, x, contexts) = nothing
33-
3416
DI.check_available(::AutoZygote) = true
3517
DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported()
3618

@@ -64,7 +46,6 @@ function DI.value_and_pullback(
6446
tx = map(ty) do dy
6547
first(pb(dy))
6648
end
67-
check_nothing(first(tx), f, x, contexts)
6849
return y, tx
6950
end
7051

@@ -80,7 +61,6 @@ function DI.value_and_pullback(
8061
tx = map(ty) do dy
8162
first(pb(dy))
8263
end
83-
check_nothing(first(tx), f, x, contexts)
8464
return copy(y), tx
8565
end
8666

@@ -96,7 +76,6 @@ function DI.pullback(
9676
tx = map(ty) do dy
9777
first(pb(dy))
9878
end
99-
check_nothing(first(tx), f, x, contexts)
10079
return tx
10180
end
10281

@@ -110,15 +89,13 @@ function DI.value_and_gradient(
11089
f, ::DI.NoGradientPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C}
11190
) where {C}
11291
(; val, grad) = withgradient(f, x, map(translate, contexts)...)
113-
check_nothing(first(grad), f, x, contexts)
11492
return val, first(grad)
11593
end
11694

11795
function DI.gradient(
11896
f, ::DI.NoGradientPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C}
11997
) where {C}
12098
grad = gradient(f, x, map(translate, contexts)...)
121-
check_nothing(first(grad), f, x, contexts)
12299
return first(grad)
123100
end
124101

@@ -147,15 +124,13 @@ function DI.value_and_jacobian(
147124
y = f(x, map(translate, contexts)...)
148125
# https://github.com/FluxML/Zygote.jl/issues/1506
149126
jac = jacobian(f, x, map(translate, contexts)...)
150-
check_nothing(first(jac), f, x, contexts)
151127
return y, first(jac)
152128
end
153129

154130
function DI.jacobian(
155131
f, ::DI.NoJacobianPrep, ::AutoZygote, x, contexts::Vararg{DI.Context,C}
156132
) where {C}
157133
jac = jacobian(f, x, map(translate, contexts)...)
158-
check_nothing(first(jac), f, x, contexts)
159134
return first(jac)
160135
end
161136

@@ -242,7 +217,6 @@ function DI.hessian(
242217
) where {C}
243218
fc = DI.with_contexts(f, contexts...)
244219
hess = hessian(fc, x)
245-
check_nothing(hess, f, x, contexts)
246220
return hess
247221
end
248222

DifferentiationInterface/test/Back/Zygote/test.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,3 @@ test_differentiation(
5656
logging=LOGGING,
5757
)
5858
end
59-
60-
## Errors
61-
62-
@testset "Errors" begin
63-
safe_log(x) = x > zero(x) ? log(x) : convert(typeof(x), NaN)
64-
@test_throws "Zygote failed to differentiate" derivative(safe_log, AutoZygote(), 0.0)
65-
end

0 commit comments

Comments
 (0)