Skip to content

Commit 1edbc76

Browse files
authored
Merge branch 'master' into patch-2
2 parents 334164a + 53bd8ff commit 1edbc76

File tree

9 files changed

+104
-13
lines changed

9 files changed

+104
-13
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.6.5"
3+
version = "0.6.9"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1212

1313
[compat]
14-
ChainRulesCore = "0.9.13"
14+
ChainRulesCore = "0.9.39"
1515
Compat = "3"
1616
FiniteDifferences = "0.12"
1717
julia = "1"

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ Test.DefaultTestSet("test_scalar: relu at -0.5", Any[Test.DefaultTestSet("with t
117117
[`test_frule`](@ref) and [`test_rrule`](@ref) allow you to specify the tangents used for testing.
118118
This is done by passing in `x ⊢ Δx`, where `x` is the primal and `Δx` is the tangent, in the place of the primal inputs.
119119
If this is not done the tangent will be automatically generated via [`ChainRulesTestUtils.rand_tangent`](@ref).
120-
A special case of this is that if you specify it as `x ⊢ nothing` then finite differencing will not be used on that input.
120+
A special case of this is that if you specify it as `x ⊢ DoesNotExist()` then finite differencing will not be used on that input.
121121
Similarly, by setting the `output_tangent` keyword argument, you can specify the tangent for the primal output.
122122

123123
This can be useful when the default provided [`ChainRulesTestUtils.rand_tangent`](@ref) doesn't produce the desired tangent for your type.

src/check_result.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,22 @@ check_equal(::Zero, x; kwargs...) = check_equal(zero(x), x; kwargs...)
2929
check_equal(x, ::Zero; kwargs...) = check_equal(x, zero(x); kwargs...)
3030
check_equal(x::Zero, y::Zero; kwargs...) = @test true
3131

32+
# remove once https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
33+
check_equal(x::DoesNotExist, y::Nothing; kwargs...) = @test true
34+
check_equal(x::Nothing, y::DoesNotExist; kwargs...) = @test true
35+
36+
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
37+
# not yet been implemented
38+
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
39+
check_equal(x::ChainRulesCore.NotImplemented, y; kwargs...) = @test_broken x == y
40+
check_equal(x, y::ChainRulesCore.NotImplemented; kwargs...) = @test_broken x == y
41+
# In this case we check for equality (messages etc. have to be equal)
42+
function check_equal(
43+
x::ChainRulesCore.NotImplemented, y::ChainRulesCore.NotImplemented; kwargs...
44+
)
45+
return @test x == y
46+
end
47+
3248
"""
3349
_can_pass_early(actual, expected; kwargs...)
3450
Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper;
@@ -133,3 +149,20 @@ function _check_add!!_behaviour(acc, val; kwargs...)
133149
acc_mutated = deepcopy(acc) # prevent this test changing others
134150
check_equal(add!!(acc_mutated, val), acc + val; kwargs...)
135151
end
152+
153+
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
154+
# not yet been implemented
155+
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
156+
function _check_add!!_behaviour(acc_mutated, acc::ChainRulesCore.NotImplemented; kwargs...)
157+
return @test_broken acc_mutated == acc
158+
end
159+
function _check_add!!_behaviour(acc_mutated::ChainRulesCore.NotImplemented, acc; kwargs...)
160+
return @test_broken acc_mutated == acc
161+
end
162+
# In this case we check for equality (messages etc. have to be equal)
163+
function _check_add!!_behaviour(
164+
acc_mutated::ChainRulesCore.NotImplemented, acc::ChainRulesCore.NotImplemented;
165+
kwargs...,
166+
)
167+
return @test acc_mutated == acc
168+
end

src/generate_tangent.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ rand_tangent(rng::AbstractRNG, x::Integer) = DoesNotExist()
5454

5555
rand_tangent(rng::AbstractRNG, x::T) where {T<:Number} = randn(rng, T)
5656

57+
# ref: https://github.com/JuliaLang/julia/issues/17629
58+
rand_tangent(rng::AbstractRNG, ::BigFloat) = big(randn(rng))
59+
5760
rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)
5861

5962
function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple}

src/testers.jl

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ end
7575
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
7676
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
7777
- `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
78-
Non-differentiable arguments, such as indices, should have `ẋ` set as `nothing`.
78+
Non-differentiable arguments, such as indices, should have `ẋ` set as `DoesNotExist()`.
7979
8080
# Keyword Arguments
8181
- `output_tangent` tangent to test accumulation of derivatives against
@@ -113,7 +113,16 @@ function test_frule(
113113
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
114114
check_equal(Ω_ad, Ω; isapprox_kwargs...)
115115

116-
ẋs_is_ignored = ẋs .== nothing
116+
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
117+
ẋs_is_ignored = isa.(ẋs, Union{Nothing, DoesNotExist})
118+
if any(ẋs .== nothing)
119+
Base.depwarn(
120+
"test_frule(f, k ⊢ nothing) is deprecated, use " *
121+
"test_frule(f, k ⊢ DoesNotExist()) instead for non-differentiable ks",
122+
:test_frule
123+
)
124+
end
125+
117126
# Correctness testing via finite differencing.
118127
dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), Ω, xs, ẋs, ẋs_is_ignored)
119128
check_equal(dΩ_ad, dΩ_fd; isapprox_kwargs...)
@@ -133,7 +142,7 @@ end
133142
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
134143
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
135144
- `x̄`: currently accumulated cotangent, will be generated automatically if not provided
136-
Non-differentiable arguments, such as indices, should have `x̄` set as `nothing`.
145+
Non-differentiable arguments, such as indices, should have `x̄` set as `DoesNotExist()`.
137146
138147
# Keyword Arguments
139148
- `output_tangent` the seed to propagate backward for testing (techncally a cotangent).
@@ -181,11 +190,24 @@ function test_rrule(
181190
@test ∂self === NO_FIELDS # No internal fields
182191

183192
# Correctness testing via finite differencing.
184-
x̄s_is_dne = accumulated_x̄ .== nothing
193+
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
194+
x̄s_is_dne = isa.(accumulated_x̄, Union{Nothing, DoesNotExist})
195+
if any(accumulated_x̄ .== nothing)
196+
Base.depwarn(
197+
"test_rrule(f, k ⊢ nothing) is deprecated, use " *
198+
"test_rrule(f, k ⊢ DoesNotExist()) instead for non-differentiable ks",
199+
:test_rrule
200+
)
201+
end
202+
185203
x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_dne)
186204
for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd)
187-
if accumulated_x̄ === nothing # then we marked this argument as not differentiable
205+
if accumulated_x̄ isa Union{Nothing, DoesNotExist} # then we marked this argument as not differentiable # TODO remove once #113
188206
@assert x̄_fd === nothing # this is how `_make_j′vp_call` works
207+
x̄_ad isa Zero && error(
208+
"The pullback in the rrule for $f function should use DoesNotExist()" *
209+
" rather than Zero() for non-perturbable arguments."
210+
)
189211
@test x̄_ad isa DoesNotExist # we said it wasn't differentiable.
190212
else
191213
x̄_ad isa AbstractThunk && check_inferred && _test_inferred(unthunk, x̄_ad)

test/check_result.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end
3333

3434
@testset "check_equal" begin
3535

36-
@testset "possive cases" begin
36+
@testset "positive cases" begin
3737
check_equal(1.0, 1.0)
3838
check_equal(1.0 + im, 1.0 + im)
3939
check_equal(1.0, 1.0+1e-10) # isapprox _behaviour
@@ -46,6 +46,10 @@ end
4646

4747
check_equal(@thunk(10*0.1*[[1.0], [2.0]]), [[1.0], [2.0]])
4848

49+
check_equal(@not_implemented(""), rand(3))
50+
check_equal(rand(3), @not_implemented(""))
51+
check_equal(@not_implemented("a"), @not_implemented("a"))
52+
4953
check_equal(
5054
Composite{Tuple{Float64, Float64}}(1.0, 2.0),
5155
Composite{Tuple{Float64, Float64}}(1.0, 2.0)
@@ -92,6 +96,8 @@ end
9296
@test fails(()->check_equal([[1.0], [2.0]], [[1.1], [2.0]]))
9397

9498
@test fails(()->check_equal(@thunk(10*[[1.0], [2.0]]), [[1.0], [2.0]]))
99+
100+
@test fails(()->check_equal(@not_implemented("a"), @not_implemented("b")))
95101
end
96102
@testset "type negative" begin
97103
@test fails() do # these have different primals so should not be equal

test/generate_tangent.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ end
1717
(true, DoesNotExist),
1818
(4, DoesNotExist),
1919
(5.0, Float64),
20+
(big(5.0), BigFloat),
2021
(5.0 + 0.4im, Complex{Float64}),
2122
(randn(Float32, 3), Vector{Float32}),
2223
(randn(Complex{Float64}, 2), Vector{Complex{Float64}}),

test/meta_testing_tools.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,10 @@ end
121121
@testset "Single Test" begin
122122
fails = nonpassing_results(()->@test false)
123123
@test length(fails) === 1
124-
@test fails[1].orig_expr == false
124+
# Julia 1.6 return a `String`, not an `Expr`.
125+
# Always calling `string` on it gives gives consistency regardless of version.
126+
# https://github.com/JuliaLang/julia/pull/37809
127+
@test string(fails[1].orig_expr) == string(false)
125128
end
126129

127130
@testset "Single Testset" begin
@@ -132,8 +135,12 @@ end
132135
end
133136
end
134137
@test length(fails) === 2
135-
@test fails[1].orig_expr == :(false==true)
136-
@test fails[2].orig_expr == :(true==false)
138+
139+
# Julia 1.6 return a `String`, not an `Expr`.
140+
# Always calling `string` on it gives gives consistency regardless of version.
141+
# https://github.com/JuliaLang/julia/pull/37809
142+
@test string(fails[1].orig_expr) == string(:(false == true))
143+
@test string(fails[2].orig_expr) == string(:(true == false))
137144
end
138145

139146

test/testers.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ f_noninferrable_pullback(x) = x
1414
f_noninferrable_thunk(x, y) = x + y
1515
f_inferrable_pullback_only(x) = x > 0 ? Float64(x) : Float32(x)
1616

17-
1817
function finplace!(x; y = [1])
1918
y[1] = 2
2019
x .*= y[1]
@@ -507,4 +506,24 @@ end
507506
end
508507
test_rrule(rev_trouble, (3, 3.0) Composite{Tuple{Int, Float64}}(Zero(), 1.0))
509508
end
509+
510+
@testset "error message about incorrectly using Zero()" begin
511+
foo(a, i) = a[i]
512+
function ChainRulesCore.rrule(::typeof(foo), a, i)
513+
function foo_pullback(Δy)
514+
da = zeros(size(a))
515+
da[i] = Δy
516+
return NO_FIELDS, da, Zero()
517+
end
518+
return foo(a, i), foo_pullback
519+
end
520+
@test errors(() -> test_rrule(foo, [1.0, 2.0, 3.0], 2), "should use DoesNotExist()")
521+
end
522+
523+
@testset "NotImplemented" begin
524+
f_notimplemented(x, y) = (x + y, x - y)
525+
@scalar_rule f_notimplemented(x, y) (@not_implemented(""), 1) (1, -1)
526+
test_frule(f_notimplemented, randn(), randn())
527+
test_rrule(f_notimplemented, randn(), randn())
528+
end
510529
end

0 commit comments

Comments
 (0)