Skip to content

Commit 90ddcae

Browse files
authored
Merge pull request #6 from JuliaDiff/MB/v1
Initial Port from ChainRules.jl
2 parents 1c39ca4 + 963922d commit 90ddcae

File tree

6 files changed

+251
-1
lines changed

6 files changed

+251
-1
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,6 @@ docs/site/
2222
# committed for packages, but should be committed for applications that require a static
2323
# environment.
2424
Manifest.toml
25+
26+
# JetBrains meta files
27+
.idea/*

.travis.yml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
language: julia
2+
os:
3+
- linux
4+
- osx
5+
julia:
6+
- 1.0
7+
- 1.3
8+
- nightly
9+
10+
notifications:
11+
email:
12+
recipients:
13+
14+
on_success: never
15+
on_failure: always
16+
if: type = cron
17+
matrix:
18+
allow_failures:
19+
- julia: nightly

Project.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
name = "ChainRulesTestUtils"
2+
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3+
version = "0.1.0"
4+
5+
[deps]
6+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
7+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
8+
9+
[compat]
10+
ChainRulesCore = "0.5, 0.6"
11+
FiniteDifferences = "0.7, 0.8, 0.9"
12+
julia = "1"
13+
14+
[extras]
15+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
16+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
17+
18+
[targets]
19+
test = ["Random", "Test"]

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1-
# ChainRulesTestUtils.jl
1+
# ChainRulesTestUtils.jl
2+
3+
[![Travis](https://travis-ci.org/JuliaDiff/ChainRulesTestUtils.jl.svg?branch=master)](https://travis-ci.org/JuliaDiff/ChainRulesTestUtils.jl)
4+
5+
`ChainRulesTestUtils.jl` provides a variety of common utilities for testing forward- and reverse- primitives.

src/ChainRulesTestUtils.jl

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
module ChainRulesTestUtils
2+
3+
using ChainRulesCore
4+
using ChainRulesCore: frule, rrule
5+
using ChainRulesCore: AbstractDifferential
6+
using FiniteDifferences
7+
using Test
8+
9+
const _fdm = central_fdm(5, 1)
10+
11+
export test_scalar, frule_test, rrule_test, isapprox, generate_well_conditioned_matrix
12+
13+
Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...) = error("Tried to differentiate w.r.t. a `DoesNotExist`")
14+
Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) = isapprox(extern(d_ad), d_fd; kwargs...)
15+
16+
function _make_fdm_call(fdm, f, ȳ, xs, ignores)
17+
sig = Expr(:tuple)
18+
call = Expr(:call, f)
19+
newxs = Any[]
20+
arginds = Int[]
21+
i = 1
22+
for (x, ignore) in zip(xs, ignores)
23+
if ignore
24+
push!(call.args, x)
25+
else
26+
push!(call.args, Symbol(:x, i))
27+
push!(sig.args, Symbol(:x, i))
28+
push!(newxs, x)
29+
push!(arginds, i)
30+
end
31+
i += 1
32+
end
33+
fdexpr = :(j′vp($fdm, $sig -> $call, $ȳ, $(newxs...)))
34+
fd = eval(fdexpr)
35+
fd isa Tuple || (fd = (fd,))
36+
args = Any[nothing for _ in 1:length(xs)]
37+
for (dx, ind) in zip(fd, arginds)
38+
args[ind] = dx
39+
end
40+
return (args...,)
41+
end
42+
43+
# Useful for LinearAlgebra tests
44+
function generate_well_conditioned_matrix(rng, N)
45+
A = randn(rng, N, N)
46+
return A * A' + I
47+
end
48+
49+
"""
50+
test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
51+
52+
Given a function `f` with scalar input and scalar output, perform finite differencing checks,
53+
at input point `x` to confirm that there are correct `frule` and `rrule`s provided.
54+
55+
# Arguments
56+
- `f`: Function for which the `frule` and `rrule` should be tested.
57+
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
58+
59+
All keyword arguments except for `fdm` is passed to `isapprox`.
60+
"""
61+
function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
62+
ensure_not_running_on_functor(f, "test_scalar")
63+
64+
r_res = rrule(f, x)
65+
f_res = frule(f, x, Zero(), 1)
66+
@test r_res !== nothing # Check the rule was defined
67+
@test f_res !== nothing
68+
r_fx, prop_rule = r_res
69+
f_fx, f_∂x = f_res
70+
@testset "$f at $x, $(nameof(rule))" for (rule, fx, ∂x) in (
71+
(rrule, r_fx, prop_rule(1)),
72+
(frule, f_fx, f_∂x)
73+
)
74+
@test fx == f(x) # Check we still get the normal value, right
75+
76+
if rule == rrule
77+
∂self, ∂x = ∂x
78+
@test ∂self === NO_FIELDS
79+
end
80+
@test isapprox(∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs...)
81+
end
82+
end
83+
84+
function ensure_not_running_on_functor(f, name)
85+
# if x itself is a Type, then it is a constructor, thus not a functor.
86+
# This also catchs UnionAll constructors which have a `:var` and `:body` fields
87+
f isa Type && return
88+
89+
if fieldcount(typeof(f)) > 0
90+
throw(ArgumentError(
91+
"$name cannot be used on closures/functors (such as $f)"
92+
))
93+
end
94+
end
95+
96+
"""
97+
frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
98+
99+
# Arguments
100+
- `f`: Function for which the `frule` should be tested.
101+
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
102+
- `ẋ`: differential w.r.t. `x` (should generally be set randomly).
103+
104+
All keyword arguments except for `fdm` are passed to `isapprox`.
105+
"""
106+
function frule_test(f, (x, ẋ); rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
107+
return frule_test(f, ((x, ẋ),); rtol=rtol, atol=atol, fdm=fdm, kwargs...)
108+
end
109+
110+
function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
111+
ensure_not_running_on_functor(f, "frule_test")
112+
xs, ẋs = collect(zip(xẋs...))
113+
Ω, dΩ_ad = frule(f, xs..., NO_FIELDS, ẋs...)
114+
@test f(xs...) == Ω
115+
116+
# Correctness testing via finite differencing.
117+
dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs))
118+
@test isapprox(
119+
collect(extern.(dΩ_ad)), # Use collect so can use vector equality
120+
collect(dΩ_fd);
121+
rtol=rtol,
122+
atol=atol,
123+
kwargs...
124+
)
125+
end
126+
127+
128+
"""
129+
rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
130+
131+
# Arguments
132+
- `f`: Function to which rule should be applied.
133+
- `ȳ`: adjoint w.r.t. output of `f` (should generally be set randomly).
134+
Should be same structure as `f(x)` (so if multiple returns should be a tuple)
135+
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
136+
- `x̄`: currently accumulated adjoint (should generally be set randomly).
137+
138+
All keyword arguments except for `fdm` are passed to `isapprox`.
139+
"""
140+
function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
141+
ensure_not_running_on_functor(f, "rrule_test")
142+
143+
# Check correctness of evaluation.
144+
fx, pullback = rrule(f, x)
145+
@test collect(fx) collect(f(x)) # use collect so can do vector equality
146+
(∂self, x̄_ad) = if fx isa Tuple
147+
# If the function returned multiple values,
148+
# then it must have multiple seeds for propagating backwards
149+
pullback(ȳ...)
150+
else
151+
pullback(ȳ)
152+
end
153+
154+
@test ∂self === NO_FIELDS # No internal fields
155+
# Correctness testing via finite differencing.
156+
x̄_fd = j′vp(fdm, f, ȳ, x)
157+
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)
158+
end
159+
160+
# case where `f` takes multiple arguments
161+
function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
162+
ensure_not_running_on_functor(f, "rrule_test")
163+
164+
# Check correctness of evaluation.
165+
xs, x̄s = collect(zip(xx̄s...))
166+
y, pullback = rrule(f, xs...)
167+
@test f(xs...) == y
168+
169+
@assert !(isa(ȳ, Thunk))
170+
∂s = pullback(ȳ)
171+
∂self = ∂s[1]
172+
x̄s_ad = ∂s[2:end]
173+
@test ∂self === NO_FIELDS
174+
175+
# Correctness testing via finite differencing.
176+
x̄s_fd = j′vp(fdm, f, ȳ, xs...)
177+
map(x̄s_ad, x̄s_fd) do x̄_ad, x̄_fd
178+
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)
179+
end
180+
end
181+
182+
end # module

test/runtests.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using ChainRulesCore
2+
using ChainRulesTestUtils
3+
using Random
4+
using Test
5+
6+
@testset "ChainRulesTestUtils.jl" begin
7+
double(x) = 2x
8+
@scalar_rule(double(x), 2)
9+
test_scalar(double, 2)
10+
11+
fst(x, y) = x
12+
ChainRulesCore.frule(::typeof(fst), x, y, _, dx, dy) = (x, dx)
13+
14+
function ChainRulesCore.rrule(::typeof(fst), x, y)
15+
function fst_pullback(Δx)
16+
return (NO_FIELDS, Δx, Zero())
17+
end
18+
return x, fst_pullback
19+
end
20+
21+
frule_test(fst, (2, 4.0), (3, 5.0))
22+
rrule_test(fst, rand(), (2.0, 4.0), (3.0, 5.0))
23+
end

0 commit comments

Comments
 (0)