Skip to content

Commit 8c71c8a

Browse files
committed
Add Eucldiean rules + refactor Project.toml
1 parent 939c154 commit 8c71c8a

File tree

3 files changed

+34
-0
lines changed

3 files changed

+34
-0
lines changed

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,10 @@ version = "0.1.0"
55
[deps]
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
9+
10+
[extras]
811
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
12+
13+
[targets]
14+
test = ["Test"]

src/KernelFunctions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Distances, LinearAlgebra
88
const defaultobs = 2
99
abstract type Kernel{T<:Real} end
1010

11+
include("zygote_rules.jl")
1112
include("utils.jl")
1213
include("common.jl")
1314
include("kernelmatrix.jl")

src/zygote_rules.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using ZygoteRules
2+
3+
@adjoint function colwise(s::Euclidean, x::AbstractMatrix, y::AbstractMatrix)
4+
d = colwise(s, x, y)
5+
return d, function::AbstractVector)
6+
=./ d)' .* (x .- y)
7+
return nothing, x̄, -
8+
end
9+
end
10+
11+
@adjoint function pairwise(::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
12+
@assert dims == 2
13+
D, back = Zygote.forward((X, Y)->pairwise(SqEuclidean(), X, Y; dims=2), X, Y)
14+
D .= sqrt.(D)
15+
return D, Δ -> (nothing, back./ (2 .* D))...)
16+
end
17+
18+
@adjoint function pairwise(::Euclidean, X::AbstractMatrix; dims=2)
19+
@assert dims == 2
20+
D, back = Zygote.forward(X->pairwise(SqEuclidean(), X; dims=2), X)
21+
D .= sqrt.(D)
22+
return D, function(Δ)
23+
Δ = Δ ./ (2 .* D)
24+
Δ[diagind(Δ)] .= 0
25+
return (nothing, first(back(Δ)))
26+
end
27+
end

0 commit comments

Comments
 (0)