Skip to content

Commit 8bb99ea

Browse files
Merge pull request #20 from JuliaDiff/npr/fdm0.9
Update FiniteDifferences
2 parents 41f6ae3 + 5594a8c commit 8bb99ea

File tree

3 files changed

+52
-19
lines changed

3 files changed

+52
-19
lines changed

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.2.0"
3+
version = "0.2.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
7+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
78
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1011

1112
[compat]
12-
ChainRulesCore = "0.7"
13-
FiniteDifferences = "0.7, 0.8, 0.9"
13+
ChainRulesCore = "0.7.1"
14+
Compat = "3"
15+
FiniteDifferences = "0.9"
1416
julia = "1"
1517

1618
[extras]

src/ChainRulesTestUtils.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module ChainRulesTestUtils
22

33
using ChainRulesCore
44
using ChainRulesCore: frule, rrule
5-
using ChainRulesCore: AbstractDifferential
5+
using Compat: only
66
using FiniteDifferences
77
using LinearAlgebra
88
using Test
@@ -11,8 +11,12 @@ const _fdm = central_fdm(5, 1)
1111

1212
export test_scalar, frule_test, rrule_test, isapprox, generate_well_conditioned_matrix
1313

14+
# TODO: reconsider these https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/7
15+
Base.isapprox(a, b::Union{AbstractZero, AbstractThunk}; kwargs...) = isapprox(b, a; kwargs...)
16+
Base.isapprox(d_ad::AbstractThunk, d_fd; kwargs...) = isapprox(extern(d_ad), d_fd; kwargs...)
1417
Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...) = error("Tried to differentiate w.r.t. a `DoesNotExist`")
15-
Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) = isapprox(extern(d_ad), d_fd; kwargs...)
18+
# Call `all` to handle the case where `Zero` is standing in for a non-scalar zero
19+
Base.isapprox(d_ad::Zero, d_fd; kwargs...) = all(isapprox.(extern(d_ad), d_fd; kwargs...))
1620

1721
"""
1822
_make_fdm_call(fdm, f, ȳ, xs, ignores) -> Tuple
@@ -49,7 +53,6 @@ function _make_fdm_call(fdm, f, ȳ, xs, ignores)
4953
end
5054
fdexpr = :(j′vp($fdm, $sig -> $call, $ȳ, $(newxs...)))
5155
fd = eval(fdexpr)
52-
fd isa Tuple || (fd = (fd,))
5356
args = Any[nothing for _ in 1:length(xs)]
5457
for (dx, ind) in zip(fd, arginds)
5558
args[ind] = dx
@@ -170,7 +173,7 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm
170173

171174
@test ∂self === NO_FIELDS # No internal fields
172175
# Correctness testing via finite differencing.
173-
x̄_fd = j′vp(fdm, f, ȳ, x)
176+
x̄_fd = only(j′vp(fdm, f, ȳ, x)) # j′vp returns a tuple, but `f` is a unary function.
174177
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)
175178
end
176179

test/runtests.jl

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,48 @@ using Random
44
using Test
55

66
@testset "ChainRulesTestUtils.jl" begin
7-
double(x) = 2x
8-
@scalar_rule(double(x), 2)
9-
test_scalar(double, 2)
10-
11-
fst(x, y) = x
12-
ChainRulesCore.frule((_, dx, dy), ::typeof(fst), x, y) = (x, dx)
7+
@testset "test_scalar" begin
8+
double(x) = 2x
9+
@scalar_rule(double(x), 2)
10+
test_scalar(double, 2)
11+
end
1312

14-
function ChainRulesCore.rrule(::typeof(fst), x, y)
15-
function fst_pullback(Δx)
16-
return (NO_FIELDS, Δx, Zero())
13+
@testset "unary: identity(x)" begin
14+
function ChainRulesCore.frule((_, ẏ), ::typeof(identity), x)
15+
return x, ẏ
16+
end
17+
function ChainRulesCore.rrule(::typeof(identity), x)
18+
function identity_pullback(ȳ)
19+
return (NO_FIELDS, ȳ)
20+
end
21+
return x, identity_pullback
22+
end
23+
@testset "frule_test" begin
24+
frule_test(identity, (randn(), randn()))
25+
frule_test(identity, (randn(4), randn(4)))
26+
end
27+
@testset "rrule_test" begin
28+
rrule_test(identity, randn(), (randn(), randn()))
29+
rrule_test(identity, randn(4), (randn(4), randn(4)))
1730
end
18-
return x, fst_pullback
1931
end
2032

21-
frule_test(fst, (2, 4.0), (3, 5.0))
22-
rrule_test(fst, rand(), (2.0, 4.0), (3.0, 5.0))
33+
@testset "binary: fst(x, y)" begin
34+
fst(x, y) = x
35+
ChainRulesCore.frule((_, dx, dy), ::typeof(fst), x, y) = (x, dx)
36+
function ChainRulesCore.rrule(::typeof(fst), x, y)
37+
function fst_pullback(Δx)
38+
return (NO_FIELDS, Δx, Zero())
39+
end
40+
return x, fst_pullback
41+
end
42+
@testset "frule_test" begin
43+
frule_test(fst, (2, 4.0), (3, 5.0))
44+
frule_test(fst, (randn(4), randn(4)), (randn(4), randn(4)))
45+
end
46+
@testset "rrule_test" begin
47+
rrule_test(fst, rand(), (2.0, 4.0), (3.0, 5.0))
48+
rrule_test(fst, randn(4), (randn(4), randn(4)), (randn(4), randn(4)))
49+
end
50+
end
2351
end

0 commit comments

Comments
 (0)