Skip to content

Commit 2bfd596

Browse files
committed
Added adjoints to fix the evaluate problem
1 parent 4fac93f commit 2bfd596

File tree

3 files changed

+27
-0
lines changed

3 files changed

+27
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
99
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1010
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
11+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1112

1213
[compat]
1314
Distances = "0.8"
1415
PDMats = "0.9"
1516
SpecialFunctions = "0"
1617
StatsFuns = "0.8"
18+
Zygote = "0.4"
1719
julia = "1.0"
1820

1921
[extras]

src/KernelFunctions.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ export Transform, SelectTransform, ChainTransform, ScaleTransform, LowRankTransf
1616

1717
using Distances, LinearAlgebra
1818
using SpecialFunctions: lgamma, besselk
19+
using Zygote: @adjoint
1920
using StatsFuns: logtwo
2021
using PDMats: PDMat
2122

@@ -42,4 +43,6 @@ include("kernels/kernelproduct.jl")
4243

4344
include("generic.jl")
4445

46+
include("zygote_adjoints.jl")
47+
4548
end

src/zygote_adjoints.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
@adjoint function evaluate(s::SqEuclidean, x::AbstractVector, y::AbstractVector)
2+
δ = x .- y
3+
sum(abs2, δ), Δ -> begin
4+
= (2 * Δ) .* δ
5+
(nothing, x̄, -x̄)
6+
end
7+
end
8+
9+
@adjoint function evaluate(s::Euclidean, x::AbstractVector, y::AbstractVector)
10+
D = x.-y
11+
δ = sqrt(sum(abs2,D))
12+
δ, Δ -> begin
13+
= Δ .* D /+ eps(δ))
14+
(nothing, x̄, -x̄)
15+
end
16+
end
17+
18+
@adjoint function evaluate(s::DotProduct, x::AbstractVector, y::AbstractVector)
19+
dot(x,y), Δ -> begin
20+
(nothing, Δ.*y, Δ.*x)
21+
end
22+
end

0 commit comments

Comments
 (0)