Skip to content

Commit e237992

Browse files
authored
Merge pull request #5 from theogf/wct/zygote-euclidean
Add Euclidean rules + refactor Project.toml
2 parents 939c154 + ed0f44c commit e237992

File tree

6 files changed

+146
-2
lines changed

6 files changed

+146
-2
lines changed

Project.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,15 @@ version = "0.1.0"
55
[deps]
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
9+
10+
[compat]
11+
FiniteDifferences = ">= 0.7.2"
12+
13+
[extras]
14+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
15+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
816
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
17+
18+
[targets]
19+
test = ["FiniteDifferences", "Random", "Test"]

src/KernelFunctions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Distances, LinearAlgebra
88
const defaultobs = 2
99
abstract type Kernel{T<:Real} end
1010

11+
include("zygote_rules.jl")
1112
include("utils.jl")
1213
include("common.jl")
1314
include("kernelmatrix.jl")

src/zygote_rules.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using Zygote: @adjoint, forward
2+
3+
@adjoint function colwise(s::Euclidean, x::AbstractMatrix, y::AbstractMatrix)
4+
d = colwise(s, x, y)
5+
return d, function::AbstractVector)
6+
=./ d)' .* (x .- y)
7+
return nothing, x̄, -
8+
end
9+
end
10+
11+
@adjoint function pairwise(::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
12+
@assert dims == 2
13+
D, back = forward((X, Y)->pairwise(SqEuclidean(), X, Y; dims=2), X, Y)
14+
D .= sqrt.(D)
15+
return D, Δ -> (nothing, back./ (2 .* D))...)
16+
end
17+
18+
@adjoint function pairwise(::Euclidean, X::AbstractMatrix; dims=2)
19+
@assert dims == 2
20+
D, back = forward(X->pairwise(SqEuclidean(), X; dims=2), X)
21+
D .= sqrt.(D)
22+
return D, function(Δ)
23+
Δ = Δ ./ (2 .* D)
24+
Δ[diagind(Δ)] .= 0
25+
return (nothing, first(back(Δ)))
26+
end
27+
end

test/runtests.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
using Test
22
using KernelFunctions
33
using Distances
4+
using FiniteDifferences
5+
using Random
6+
using Zygote
47

5-
include("kernelmatrix.jl")
6-
include("constructors.jl")
8+
# Helpful functionality for writing tests.
9+
include("test_util.jl")
10+
11+
@testset "KernelFunctions" begin
12+
include("zygote_rules.jl")
13+
include("kernelmatrix.jl")
14+
include("constructors.jl")
15+
end

test/test_util.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using FiniteDifferences: j′vp
2+
3+
# Default tolerances for testing Zygote against FiniteDifferences.
4+
const _rtol = 1e-10
5+
const _atol = 1e-10
6+
7+
8+
9+
#
10+
# Print stuff for debugging
11+
#
12+
13+
function print_adjoints(adjoint_ad, adjoint_fd, rtol, atol)
14+
@show typeof(adjoint_ad), typeof(adjoint_fd)
15+
adjoint_ad, adjoint_fd = to_vec(adjoint_ad)[1], to_vec(adjoint_fd)[1]
16+
println("atol is $atol, rtol is $rtol")
17+
println("ad, fd, abs, rel")
18+
abs_err = abs.(adjoint_ad .- adjoint_fd)
19+
rel_err = abs.((adjoint_ad .- adjoint_fd) ./ adjoint_ad)
20+
display([adjoint_ad adjoint_fd abs_err rel_err])
21+
println()
22+
end
23+
24+
25+
26+
#
27+
# Version of isapprox that works for lots of types.
28+
#
29+
30+
function fd_isapprox(x_ad::Nothing, x_fd, rtol, atol)
31+
return fd_isapprox(x_fd, zero(x_fd), rtol, atol)
32+
end
33+
function fd_isapprox(x_ad::AbstractArray, x_fd::AbstractArray, rtol, atol)
34+
return all(fd_isapprox.(x_ad, x_fd, rtol, atol))
35+
end
36+
function fd_isapprox(x_ad::Real, x_fd::Real, rtol, atol)
37+
return isapprox(x_ad, x_fd; rtol=rtol, atol=atol)
38+
end
39+
function fd_isapprox(x_ad::NamedTuple, x_fd, rtol, atol)
40+
f = (x_ad, x_fd)->fd_isapprox(x_ad, x_fd, rtol, atol)
41+
return all([f(getfield(x_ad, key), getfield(x_fd, key)) for key in keys(x_ad)])
42+
end
43+
function fd_isapprox(x_ad::Tuple, x_fd::Tuple, rtol, atol)
44+
return all(map((x, x′)->fd_isapprox(x, x′, rtol, atol), x_ad, x_fd))
45+
end
46+
function fd_isapprox(x_ad::Dict, x_fd::Dict, rtol, atol)
47+
return all([fd_isapprox(get(()->nothing, x_ad, key), x_fd[key], rtol, atol) for
48+
key in keys(x_fd)])
49+
end
50+
51+
52+
53+
#
54+
# Check Zygote against FiniteDifferences.
55+
#
56+
57+
function adjoint_test(
58+
f, ȳ, x...;
59+
rtol=_rtol,
60+
atol=_atol,
61+
fdm=FiniteDifferences.Central(5, 1),
62+
print_results=false,
63+
)
64+
65+
# Compute forwards-pass and j′vp.
66+
y, back = Zygote.forward(f, x...)
67+
adj_ad = back(ȳ)
68+
adj_fd = j′vp(fdm, f, ȳ, x...)
69+
70+
# If unary, pull out first thing from ad.
71+
adj_ad = length(x) == 1 ? first(adj_ad) : adj_ad
72+
73+
# Check that forwards-pass agrees with plain forwards-pass.
74+
@test y f(x...)
75+
76+
# Check that ad and fd adjoints (approximately) agree.
77+
print_results && print_adjoints(adj_ad, adj_fd, rtol, atol)
78+
@test fd_isapprox(adj_ad, adj_fd, rtol, atol)
79+
end

test/zygote_rules.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
@testset "zygote_rules" begin
2+
@testset "colwise(::Euclidean, X, Y; dims=2)" begin
3+
rng, D, P = MersenneTwister(123456), 2, 3
4+
X, Y, D̄ = randn(rng, D, P), randn(rng, D, P), randn(rng, P)
5+
adjoint_test((X, Y)->colwise(Euclidean(), X, Y), D̄, X, Y)
6+
end
7+
@testset "pairwise(::Euclidean, X, Y; dims=2)" begin
8+
rng, D, P, Q = MersenneTwister(123456), 2, 3, 5
9+
X, Y, D̄ = randn(rng, D, P), randn(rng, D, Q), randn(rng, P, Q)
10+
adjoint_test((X, Y)->pairwise(Euclidean(), X, Y; dims=2), D̄, X, Y)
11+
end
12+
@testset "pairwise(::Euclidean, X; dims=2)" begin
13+
rng, D, P = MersenneTwister(123456), 2, 3
14+
X, D̄ = randn(rng, D, P), randn(rng, P, P)
15+
adjoint_test(X->pairwise(Euclidean(), X; dims=2), D̄, X)
16+
end
17+
end

0 commit comments

Comments
 (0)