Skip to content

Commit 53bd8ff

Browse files
devmotionoxinabox
andauthored
Enable testing of rules with @not_implemented (#140)
Co-authored-by: Lyndon White <[email protected]>
1 parent 0107180 commit 53bd8ff

File tree

4 files changed

+45
-4
lines changed

4 files changed

+45
-4
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.6.7"
3+
version = "0.6.8"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1212

1313
[compat]
14-
ChainRulesCore = "0.9.13"
14+
ChainRulesCore = "0.9.39"
1515
Compat = "3"
1616
FiniteDifferences = "0.12"
1717
julia = "1"

src/check_result.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@ check_equal(x::Zero, y::Zero; kwargs...) = @test true
3333
check_equal(x::DoesNotExist, y::Nothing; kwargs...) = @test true
3434
check_equal(x::Nothing, y::DoesNotExist; kwargs...) = @test true
3535

36+
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
37+
# not yet been implemented
38+
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
39+
check_equal(x::ChainRulesCore.NotImplemented, y; kwargs...) = @test_broken x == y
40+
check_equal(x, y::ChainRulesCore.NotImplemented; kwargs...) = @test_broken x == y
41+
# In this case we check for equality (messages etc. have to be equal)
42+
function check_equal(
43+
x::ChainRulesCore.NotImplemented, y::ChainRulesCore.NotImplemented; kwargs...
44+
)
45+
return @test x == y
46+
end
47+
3648
"""
3749
_can_pass_early(actual, expected; kwargs...)
3850
Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper;
@@ -137,3 +149,20 @@ function _check_add!!_behaviour(acc, val; kwargs...)
137149
acc_mutated = deepcopy(acc) # prevent this test changing others
138150
check_equal(add!!(acc_mutated, val), acc + val; kwargs...)
139151
end
152+
153+
# Checking equality with `NotImplemented` reports `@test_broken` since the derivative has intentionally
154+
# not yet been implemented
155+
# `@test_broken x == y` yields more descriptive messages than `@test_broken false`
156+
function _check_add!!_behaviour(acc_mutated, acc::ChainRulesCore.NotImplemented; kwargs...)
157+
return @test_broken acc_mutated == acc
158+
end
159+
function _check_add!!_behaviour(acc_mutated::ChainRulesCore.NotImplemented, acc; kwargs...)
160+
return @test_broken acc_mutated == acc
161+
end
162+
# In this case we check for equality (messages etc. have to be equal)
163+
function _check_add!!_behaviour(
164+
acc_mutated::ChainRulesCore.NotImplemented, acc::ChainRulesCore.NotImplemented;
165+
kwargs...,
166+
)
167+
return @test acc_mutated == acc
168+
end

test/check_result.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end
3333

3434
@testset "check_equal" begin
3535

36-
@testset "possive cases" begin
36+
@testset "positive cases" begin
3737
check_equal(1.0, 1.0)
3838
check_equal(1.0 + im, 1.0 + im)
3939
check_equal(1.0, 1.0+1e-10) # isapprox _behaviour
@@ -46,6 +46,10 @@ end
4646

4747
check_equal(@thunk(10*0.1*[[1.0], [2.0]]), [[1.0], [2.0]])
4848

49+
check_equal(@not_implemented(""), rand(3))
50+
check_equal(rand(3), @not_implemented(""))
51+
check_equal(@not_implemented("a"), @not_implemented("a"))
52+
4953
check_equal(
5054
Composite{Tuple{Float64, Float64}}(1.0, 2.0),
5155
Composite{Tuple{Float64, Float64}}(1.0, 2.0)
@@ -92,6 +96,8 @@ end
9296
@test fails(()->check_equal([[1.0], [2.0]], [[1.1], [2.0]]))
9397

9498
@test fails(()->check_equal(@thunk(10*[[1.0], [2.0]]), [[1.0], [2.0]]))
99+
100+
@test fails(()->check_equal(@not_implemented("a"), @not_implemented("b")))
95101
end
96102
@testset "type negative" begin
97103
@test fails() do # these have different primals so should not be equal

test/testers.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ f_noninferrable_pullback(x) = x
1414
f_noninferrable_thunk(x, y) = x + y
1515
f_inferrable_pullback_only(x) = x > 0 ? Float64(x) : Float32(x)
1616

17-
1817
function finplace!(x; y = [1])
1918
y[1] = 2
2019
x .*= y[1]
@@ -520,4 +519,11 @@ end
520519
end
521520
@test errors(() -> test_rrule(foo, [1.0, 2.0, 3.0], 2), "should use DoesNotExist()")
522521
end
522+
523+
@testset "NotImplemented" begin
524+
f_notimplemented(x, y) = (x + y, x - y)
525+
@scalar_rule f_notimplemented(x, y) (@not_implemented(""), 1) (1, -1)
526+
test_frule(f_notimplemented, randn(), randn())
527+
test_rrule(f_notimplemented, randn(), randn())
528+
end
523529
end

0 commit comments

Comments
 (0)