Skip to content

Commit ed0f44c

Browse files
committed
Test rules + refactor test file
1 parent 8c71c8a commit ed0f44c

File tree

5 files changed

+117
-7
lines changed

5 files changed

+117
-7
lines changed

Project.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@ version = "0.1.0"
55
[deps]
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
8+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
9+
10+
[compat]
11+
FiniteDifferences = ">= 0.7.2"
912

1013
[extras]
14+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
15+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1116
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1217

1318
[targets]
14-
test = ["Test"]
19+
test = ["FiniteDifferences", "Random", "Test"]

src/zygote_rules.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ZygoteRules
1+
using Zygote: @adjoint, forward
22

33
@adjoint function colwise(s::Euclidean, x::AbstractMatrix, y::AbstractMatrix)
44
d = colwise(s, x, y)
@@ -10,14 +10,14 @@ end
1010

1111
@adjoint function pairwise(::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
1212
@assert dims == 2
13-
D, back = Zygote.forward((X, Y)->pairwise(SqEuclidean(), X, Y; dims=2), X, Y)
13+
D, back = forward((X, Y)->pairwise(SqEuclidean(), X, Y; dims=2), X, Y)
1414
D .= sqrt.(D)
1515
return D, Δ -> (nothing, back./ (2 .* D))...)
1616
end
1717

1818
@adjoint function pairwise(::Euclidean, X::AbstractMatrix; dims=2)
1919
@assert dims == 2
20-
D, back = Zygote.forward(X->pairwise(SqEuclidean(), X; dims=2), X)
20+
D, back = forward(X->pairwise(SqEuclidean(), X; dims=2), X)
2121
D .= sqrt.(D)
2222
return D, function(Δ)
2323
Δ = Δ ./ (2 .* D)

test/runtests.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
using Test
22
using KernelFunctions
33
using Distances
4+
using FiniteDifferences
5+
using Random
6+
using Zygote
47

5-
include("kernelmatrix.jl")
6-
include("constructors.jl")
8+
# Helpful functionality for writing tests.
9+
include("test_util.jl")
10+
11+
@testset "KernelFunctions" begin
12+
include("zygote_rules.jl")
13+
include("kernelmatrix.jl")
14+
include("constructors.jl")
15+
end

test/test_util.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using FiniteDifferences: j′vp
2+
3+
# Default tolerances for testing Zygote against FiniteDifferences.
4+
const _rtol = 1e-10
5+
const _atol = 1e-10
6+
7+
8+
9+
#
10+
# Print stuff for debugging
11+
#
12+
13+
function print_adjoints(adjoint_ad, adjoint_fd, rtol, atol)
14+
@show typeof(adjoint_ad), typeof(adjoint_fd)
15+
adjoint_ad, adjoint_fd = to_vec(adjoint_ad)[1], to_vec(adjoint_fd)[1]
16+
println("atol is $atol, rtol is $rtol")
17+
println("ad, fd, abs, rel")
18+
abs_err = abs.(adjoint_ad .- adjoint_fd)
19+
rel_err = abs.((adjoint_ad .- adjoint_fd) ./ adjoint_ad)
20+
display([adjoint_ad adjoint_fd abs_err rel_err])
21+
println()
22+
end
23+
24+
25+
26+
#
27+
# Version of isapprox that works for lots of types.
28+
#
29+
30+
function fd_isapprox(x_ad::Nothing, x_fd, rtol, atol)
31+
return fd_isapprox(x_fd, zero(x_fd), rtol, atol)
32+
end
33+
function fd_isapprox(x_ad::AbstractArray, x_fd::AbstractArray, rtol, atol)
34+
return all(fd_isapprox.(x_ad, x_fd, rtol, atol))
35+
end
36+
function fd_isapprox(x_ad::Real, x_fd::Real, rtol, atol)
37+
return isapprox(x_ad, x_fd; rtol=rtol, atol=atol)
38+
end
39+
function fd_isapprox(x_ad::NamedTuple, x_fd, rtol, atol)
40+
f = (x_ad, x_fd)->fd_isapprox(x_ad, x_fd, rtol, atol)
41+
return all([f(getfield(x_ad, key), getfield(x_fd, key)) for key in keys(x_ad)])
42+
end
43+
function fd_isapprox(x_ad::Tuple, x_fd::Tuple, rtol, atol)
44+
return all(map((x, x′)->fd_isapprox(x, x′, rtol, atol), x_ad, x_fd))
45+
end
46+
function fd_isapprox(x_ad::Dict, x_fd::Dict, rtol, atol)
47+
return all([fd_isapprox(get(()->nothing, x_ad, key), x_fd[key], rtol, atol) for
48+
key in keys(x_fd)])
49+
end
50+
51+
52+
53+
#
54+
# Check Zygote against FiniteDifferences.
55+
#
56+
57+
function adjoint_test(
58+
f, ȳ, x...;
59+
rtol=_rtol,
60+
atol=_atol,
61+
fdm=FiniteDifferences.Central(5, 1),
62+
print_results=false,
63+
)
64+
65+
# Compute forwards-pass and j′vp.
66+
y, back = Zygote.forward(f, x...)
67+
adj_ad = back(ȳ)
68+
adj_fd = j′vp(fdm, f, ȳ, x...)
69+
70+
# If unary, pull out first thing from ad.
71+
adj_ad = length(x) == 1 ? first(adj_ad) : adj_ad
72+
73+
# Check that forwards-pass agrees with plain forwards-pass.
74+
@test y f(x...)
75+
76+
# Check that ad and fd adjoints (approximately) agree.
77+
print_results && print_adjoints(adj_ad, adj_fd, rtol, atol)
78+
@test fd_isapprox(adj_ad, adj_fd, rtol, atol)
79+
end

test/zygote_rules.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
@testset "zygote_rules" begin
2+
@testset "colwise(::Euclidean, X, Y; dims=2)" begin
3+
rng, D, P = MersenneTwister(123456), 2, 3
4+
X, Y, D̄ = randn(rng, D, P), randn(rng, D, P), randn(rng, P)
5+
adjoint_test((X, Y)->colwise(Euclidean(), X, Y), D̄, X, Y)
6+
end
7+
@testset "pairwise(::Euclidean, X, Y; dims=2)" begin
8+
rng, D, P, Q = MersenneTwister(123456), 2, 3, 5
9+
X, Y, D̄ = randn(rng, D, P), randn(rng, D, Q), randn(rng, P, Q)
10+
adjoint_test((X, Y)->pairwise(Euclidean(), X, Y; dims=2), D̄, X, Y)
11+
end
12+
@testset "pairwise(::Euclidean, X; dims=2)" begin
13+
rng, D, P = MersenneTwister(123456), 2, 3
14+
X, D̄ = randn(rng, D, P), randn(rng, P, P)
15+
adjoint_test(X->pairwise(Euclidean(), X; dims=2), D̄, X)
16+
end
17+
end

0 commit comments

Comments
 (0)