Skip to content

Commit 0107180

Browse files
authored
Merge pull request #139 from JuliaDiff/mz/dne
Better error message when user accidentally passes `Zero()` instead of `DoesNotExist()` in the pullback
2 parents 8b5e906 + e431c53 commit 0107180

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.6.6"
3+
version = "0.6.7"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/testers.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,10 @@ function test_rrule(
205205
for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd)
206206
if accumulated_x̄ isa Union{Nothing, DoesNotExist} # then we marked this argument as not differentiable # TODO remove once #113
207207
@assert x̄_fd === nothing # this is how `_make_j′vp_call` works
208+
x̄_ad isa Zero && error(
209+
"The pullback in the rrule for $f function should use DoesNotExist()" *
210+
" rather than Zero() for non-perturbable arguments."
211+
)
208212
@test x̄_ad isa DoesNotExist # we said it wasn't differentiable.
209213
else
210214
x̄_ad isa AbstractThunk && check_inferred && _test_inferred(unthunk, x̄_ad)

test/testers.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,4 +507,17 @@ end
507507
end
508508
test_rrule(rev_trouble, (3, 3.0) Composite{Tuple{Int, Float64}}(Zero(), 1.0))
509509
end
510+
511+
@testset "error message about incorrectly using Zero()" begin
512+
foo(a, i) = a[i]
513+
function ChainRulesCore.rrule(::typeof(foo), a, i)
514+
function foo_pullback(Δy)
515+
da = zeros(size(a))
516+
da[i] = Δy
517+
return NO_FIELDS, da, Zero()
518+
end
519+
return foo(a, i), foo_pullback
520+
end
521+
@test errors(() -> test_rrule(foo, [1.0, 2.0, 3.0], 2), "should use DoesNotExist()")
522+
end
510523
end

0 commit comments

Comments
 (0)