Skip to content

Commit 7b190ae

Browse files
authored
Merge pull request #129 from JuliaDiff/mz/dne
no more `⊢` for types for which rand_tangent gives DoesNotExist
2 parents 0833088 + 1c1a004 commit 7b190ae

File tree

4 files changed

+29
-7
lines changed

4 files changed

+29
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.6.5"
3+
version = "0.6.6"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ Test.DefaultTestSet("test_scalar: relu at -0.5", Any[Test.DefaultTestSet("with t
117117
[`test_frule`](@ref) and [`test_rrule`](@ref) allow you to specify the tangents used for testing.
118118
This is done by passing in `x ⊢ Δx`, where `x` is the primal and `Δx` is the tangent, in the place of the primal inputs.
119119
If this is not done the tangent will be automatically generated via [`ChainRulesTestUtils.rand_tangent`](@ref).
120-
A special case of this is that if you specify it as `x ⊢ nothing` then finite differencing will not be used on that input.
120+
A special case of this is that if you specify it as `x ⊢ DoesNotExist()` then finite differencing will not be used on that input.
121121
Similarly, by setting the `output_tangent` keyword argument, you can specify the tangent for the primal output.
122122

123123
This can be useful when the default provided [`ChainRulesTestUtils.rand_tangent`](@ref) doesn't produce the desired tangent for your type.

src/check_result.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ check_equal(::Zero, x; kwargs...) = check_equal(zero(x), x; kwargs...)
2929
check_equal(x, ::Zero; kwargs...) = check_equal(x, zero(x); kwargs...)
3030
check_equal(x::Zero, y::Zero; kwargs...) = @test true
3131

32+
# remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
33+
check_equal(x::DoesNotExist, y::Nothing; kwargs...) = @test true
34+
check_equal(x::Nothing, y::DoesNotExist; kwargs...) = @test true
35+
3236
"""
3337
_can_pass_early(actual, expected; kwargs...)
3438
Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper;

src/testers.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ end
7575
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
7676
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
7777
- `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
78-
Non-differentiable arguments, such as indices, should have `ẋ` set as `nothing`.
78+
Non-differentiable arguments, such as indices, should have `ẋ` set as `DoesNotExist()`.
7979
8080
# Keyword Arguments
8181
- `output_tangent` tangent to test accumulation of derivatives against
@@ -114,7 +114,16 @@ function test_frule(
114114
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
115115
check_equal(Ω_ad, Ω; isapprox_kwargs...)
116116

117-
ẋs_is_ignored = ẋs .== nothing
117+
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
118+
ẋs_is_ignored = isa.(ẋs, Union{Nothing, DoesNotExist})
119+
if any(ẋs .== nothing)
120+
Base.depwarn(
121+
"test_frule(f, k ⊢ nothing) is deprecated, use " *
122+
"test_frule(f, k ⊢ DoesNotExist()) instead for non-differentiable ks",
123+
:test_frule
124+
)
125+
end
126+
118127
# Correctness testing via finite differencing.
119128
dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), Ω, xs, ẋs, ẋs_is_ignored)
120129
check_equal(dΩ_ad, dΩ_fd; isapprox_kwargs...)
@@ -134,7 +143,7 @@ end
134143
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
135144
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
136145
- `x̄`: currently accumulated cotangent, will be generated automatically if not provided
137-
Non-differentiable arguments, such as indices, should have `x̄` set as `nothing`.
146+
Non-differentiable arguments, such as indices, should have `x̄` set as `DoesNotExist()`.
138147
139148
# Keyword Arguments
140149
- `output_tangent` the seed to propagate backward for testing (techncally a cotangent).
@@ -182,10 +191,19 @@ function test_rrule(
182191
@test ∂self === NO_FIELDS # No internal fields
183192

184193
# Correctness testing via finite differencing.
185-
x̄s_is_dne = accumulated_x̄ .== nothing
194+
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
195+
x̄s_is_dne = isa.(accumulated_x̄, Union{Nothing, DoesNotExist})
196+
if any(accumulated_x̄ .== nothing)
197+
Base.depwarn(
198+
"test_rrule(f, k ⊢ nothing) is deprecated, use " *
199+
"test_rrule(f, k ⊢ DoesNotExist()) instead for non-differentiable ks",
200+
:test_rrule
201+
)
202+
end
203+
186204
x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_dne)
187205
for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd)
188-
if accumulated_x̄ === nothing # then we marked this argument as not differentiable
206+
if accumulated_x̄ isa Union{Nothing, DoesNotExist} # then we marked this argument as not differentiable # TODO remove once #113
189207
@assert x̄_fd === nothing # this is how `_make_j′vp_call` works
190208
@test x̄_ad isa DoesNotExist # we said it wasn't differentiable.
191209
else

0 commit comments

Comments
 (0)