Skip to content

Commit 5882876

Browse files
authored
Merge pull request #15 from theogf/fixdiagmatrix
Added adjoints for evaluate
2 parents 5e2b139 + 360e056 commit 5882876

File tree

7 files changed

+66
-3
lines changed

7 files changed

+66
-3
lines changed

.travis.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@ os:
55
- osx
66
julia:
77
- 1.0
8+
- 1.2
9+
- 1.3
810
- nightly
11+
# because of Zygote needs to allow failing on nightly
12+
matrix:
13+
allow_failures:
14+
- julia: nightly
915
notifications:
1016
email: false
1117
after_success:

Project.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,27 @@ uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
33
version = "0.2.2"
44

55
[deps]
6+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
67
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
910
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1011
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
12+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1113

1214
[compat]
15+
Compat = "2.2"
1316
Distances = "0.8"
1417
PDMats = "0.9"
1518
SpecialFunctions = "0"
1619
StatsFuns = "0.8"
20+
Zygote = "0.4"
1721
julia = "1.0"
1822

1923
[extras]
2024
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2125
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
26+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2227

2328
[targets]
24-
test = ["Random", "Test"]
29+
test = ["Random", "Test", "FiniteDifferences"]

src/KernelFunctions.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ export KernelSum, KernelProduct
1414

1515
export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
1616

17+
using Compat
1718
using Distances, LinearAlgebra
1819
using SpecialFunctions: lgamma, besselk
20+
using Zygote: @adjoint
1921
using StatsFuns: logtwo
2022
using PDMats: PDMat
2123

@@ -42,4 +44,6 @@ include("kernels/kernelproduct.jl")
4244

4345
include("generic.jl")
4446

47+
include("zygote_adjoints.jl")
48+
4549
end

src/matrix/kernelmatrix.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ function kerneldiagmatrix(
108108
)
109109
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
110110
if obsdim == 1
111-
[@views _kernel(κ,X[i,:],X[i,:]) for i in 1:size(X,obsdim)]
111+
@compat eachrow(X) .|> x->_kernel(κ,x,x) #[@views _kernel(κ,X[i,:],X[i,:]) for i in 1:size(X,obsdim)]
112112
elseif obsdim == 2
113-
[@views _kernel(κ,X[:,i],X[:,i]) for i in 1:size(X,obsdim)]
113+
@compat eachcol(X) .|> x->_kernel(κ,x,x) #[@views _kernel(κ,X[:,i],X[:,i]) for i in 1:size(X,obsdim)]
114114
end
115115
end
116116

src/zygote_adjoints.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
@adjoint function evaluate(s::SqEuclidean, x::AbstractVector, y::AbstractVector)
2+
δ = x .- y
3+
sum(abs2, δ), Δ -> begin
4+
= (2 * Δ) .* δ
5+
(nothing, x̄, -x̄)
6+
end
7+
end
8+
9+
@adjoint function evaluate(s::Euclidean, x::AbstractVector, y::AbstractVector)
10+
D = x.-y
11+
δ = sqrt(sum(abs2,D))
12+
δ, Δ -> begin
13+
= Δ .* D /+ eps(δ))
14+
(nothing, x̄, -x̄)
15+
end
16+
end
17+
18+
@adjoint function evaluate(s::DotProduct, x::AbstractVector, y::AbstractVector)
19+
dot(x,y), Δ -> begin
20+
(nothing, Δ.*y, Δ.*x)
21+
end
22+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ include("test_transform.jl")
1111
include("test_distances.jl")
1212
include("test_kernels.jl")
1313
include("test_generic.jl")
14+
include("test_adjoints.jl")
1415
#include("types.jl")
1516
end

test/test_adjoints.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using Zygote
2+
using FiniteDifferences
3+
using KernelFunctions
4+
using Distances
5+
using Test
6+
7+
x = rand(5)
8+
y = rand(5)
9+
10+
@testset "Testing Zygote adjoints" begin
11+
gzeucl = first(Zygote.gradient(xy->evaluate(Euclidean(),xy[1],xy[2]),[x,y]))
12+
gzsqeucl = first(Zygote.gradient(xy->evaluate(SqEuclidean(),xy[1],xy[2]),[x,y]))
13+
gzdotprod = first(Zygote.gradient(xy->evaluate(KernelFunctions.DotProduct(),xy[1],xy[2]),[x,y]))
14+
15+
FDM = central_fdm(5,1)
16+
17+
gfeucl = collect(first(FiniteDifferences.grad(FDM,xy->evaluate(Euclidean(),xy[1],xy[2]),(x,y))))
18+
gfsqeucl = collect(first(FiniteDifferences.grad(FDM,xy->evaluate(SqEuclidean(),xy[1],xy[2]),(x,y))))
19+
gfdotprod =collect(first(FiniteDifferences.grad(FDM,xy->evaluate(KernelFunctions.DotProduct(),xy[1],xy[2]),(x,y))))
20+
21+
@test all(gzeucl .≈ gfeucl)
22+
@test all(gzsqeucl .≈ gfsqeucl)
23+
@test all(gzdotprod .≈ gfdotprod)
24+
25+
end

0 commit comments

Comments
 (0)