Skip to content

Commit b6876a1

Browse files
authored
Merge pull request #123 from theogf/fix_transform
Fixes #115
2 parents 602f7a6 + dd5fa10 commit b6876a1

File tree

9 files changed

+100
-62
lines changed

9 files changed

+100
-62
lines changed

src/KernelFunctions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ abstract type BaseKernel <: Kernel end
5050
abstract type SimpleKernel <: BaseKernel end
5151

5252
include("utils.jl")
53+
include("distances/pairwise.jl")
5354
include("distances/dotproduct.jl")
5455
include("distances/delta.jl")
5556
include("distances/sinus.jl")

src/distances/pairwise.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Add our own pairwise function to be able to apply it on vectors
2+
3+
pairwise(d::PreMetric, X::AbstractVector, Y::AbstractVector) = broadcast(d, X, Y')
4+
5+
pairwise(d::PreMetric, X::AbstractVector) = pairwise(d, X, X)
6+
7+
function pairwise!(
8+
out::AbstractMatrix,
9+
d::PreMetric,
10+
X::AbstractVector,
11+
Y::AbstractVector,
12+
)
13+
broadcast!(d, out, X, Y')
14+
end
15+
16+
pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector) = pairwise!(out, d, X, X)
17+
18+
function pairwise(d::PreMetric, x::AbstractVector{<:Real})
19+
return Distances.pairwise(d, reshape(x, :, 1); dims = 1)
20+
end
21+
22+
function pairwise(
23+
d::PreMetric,
24+
x::AbstractVector{<:Real},
25+
y::AbstractVector{<:Real},
26+
)
27+
return Distances.pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims = 1)
28+
end
29+
30+
function pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real})
31+
return Distances.pairwise!(out, d, reshape(x, :, 1); dims = 1)
32+
end
33+
34+
function pairwise!(
35+
out::AbstractMatrix,
36+
d::PreMetric,
37+
x::AbstractVector{<:Real},
38+
y::AbstractVector{<:Real},
39+
)
40+
return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
41+
end

src/generic.jl

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,3 @@ end
2020

2121
# Fallback implementation of evaluate for `SimpleKernel`s.
2222
(k::SimpleKernel)(x, y) = kappa(k, evaluate(metric(k), x, y))
23-
24-
# This is type piracy. We should not doing this.
25-
function Distances.pairwise(d::PreMetric, x::AbstractVector{<:Real})
26-
return pairwise(d, reshape(x, :, 1); dims=1)
27-
end
28-
29-
function Distances.pairwise(
30-
d::PreMetric,
31-
x::AbstractVector{<:Real},
32-
y::AbstractVector{<:Real},
33-
)
34-
return pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
35-
end
36-
37-
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real})
38-
return pairwise!(out, d, reshape(x, :, 1); dims=1)
39-
end
40-
41-
function Distances.pairwise!(
42-
out::AbstractMatrix,
43-
d::PreMetric,
44-
x::AbstractVector{<:Real},
45-
y::AbstractVector{<:Real},
46-
)
47-
return pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1)
48-
end

src/utils.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ Base.getindex(D::ColVecs, i) = ColVecs(view(D.X, :, i))
4343

4444
dim(x::ColVecs) = size(x.X, 1)
4545

46-
Distances.pairwise(d::PreMetric, x::ColVecs) = pairwise(d, x.X; dims=2)
47-
Distances.pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = pairwise(d, x.X, y.X; dims=2)
48-
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs)
49-
return pairwise!(out, d, x.X; dims=2)
46+
pairwise(d::PreMetric, x::ColVecs) = Distances.pairwise(d, x.X; dims=2)
47+
pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances.pairwise(d, x.X, y.X; dims=2)
48+
function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs)
49+
return Distances.pairwise!(out, d, x.X; dims=2)
5050
end
51-
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs, y::ColVecs)
52-
return pairwise!(out, d, x.X, y.X; dims=2)
51+
function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs, y::ColVecs)
52+
return Distances.pairwise!(out, d, x.X, y.X; dims=2)
5353
end
5454

5555
"""
@@ -73,13 +73,13 @@ Base.getindex(D::RowVecs, i) = RowVecs(view(D.X, i, :))
7373

7474
dim(x::RowVecs) = size(x.X, 2)
7575

76-
Distances.pairwise(d::PreMetric, x::RowVecs) = pairwise(d, x.X; dims=1)
77-
Distances.pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = pairwise(d, x.X, y.X; dims=1)
78-
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs)
79-
return pairwise!(out, d, x.X; dims=1)
76+
pairwise(d::PreMetric, x::RowVecs) = Distances.pairwise(d, x.X; dims=1)
77+
pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances.pairwise(d, x.X, y.X; dims=1)
78+
function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs)
79+
return Distances.pairwise!(out, d, x.X; dims=1)
8080
end
81-
function Distances.pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs, y::RowVecs)
82-
return pairwise!(out, d, x.X, y.X; dims=1)
81+
function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs, y::RowVecs)
82+
return Distances.pairwise!(out, d, x.X, y.X; dims=1)
8383
end
8484

8585
"""

src/zygote_adjoints.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55
end
66
end
77

8-
@adjoint function pairwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
9-
D = pairwise(d, X, Y; dims = dims)
8+
@adjoint function Distances.pairwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
9+
D = Distances.pairwise(d, X, Y; dims = dims)
1010
if dims == 1
1111
return D, Δ -> (nothing, nothing, nothing)
1212
else
1313
return D, Δ -> (nothing, nothing, nothing)
1414
end
1515
end
1616

17-
@adjoint function pairwise(d::Delta, X::AbstractMatrix; dims=2)
18-
D = pairwise(d, X; dims = dims)
17+
@adjoint function Distances.pairwise(d::Delta, X::AbstractMatrix; dims=2)
18+
D = Distances.pairwise(d, X; dims = dims)
1919
if dims == 1
2020
return D, Δ -> (nothing, nothing)
2121
else
@@ -30,17 +30,17 @@ end
3030
end
3131
end
3232

33-
@adjoint function pairwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
34-
D = pairwise(d, X, Y; dims = dims)
33+
@adjoint function Distances.pairwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
34+
D = Distances.pairwise(d, X, Y; dims = dims)
3535
if dims == 1
3636
return D, Δ -> (nothing, Δ * Y, (X' * Δ)')
3737
else
3838
return D, Δ -> (nothing, (Δ * Y')', X * Δ)
3939
end
4040
end
4141

42-
@adjoint function pairwise(d::DotProduct, X::AbstractMatrix; dims=2)
43-
D = pairwise(d, X; dims = dims)
42+
@adjoint function Distances.pairwise(d::DotProduct, X::AbstractMatrix; dims=2)
43+
D = Distances.pairwise(d, X; dims = dims)
4444
if dims == 1
4545
return D, Δ -> (nothing, 2 * Δ * X)
4646
else

test/distances/pairwise.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
@testset "pairwise" begin
2+
rng = MersenneTwister(123456)
3+
d = SqEuclidean()
4+
Ns = (4, 5)
5+
D = 3
6+
x = [randn(rng, D) for _ in 1:Ns[1]]
7+
y = [randn(rng, D) for _ in 1:Ns[2]]
8+
X = hcat(x...)
9+
Y = hcat(y...)
10+
K = zeros(Ns)
11+
12+
@test KernelFunctions.pairwise(d, x, y) pairwise(d, X, Y, dims=2)
13+
@test KernelFunctions.pairwise(d, x) pairwise(d, X, dims=2)
14+
KernelFunctions.pairwise!(K, d, x, y)
15+
@test K pairwise(d, X, Y, dims=2)
16+
K = zeros(Ns[1], Ns[1])
17+
KernelFunctions.pairwise!(K, d, x)
18+
@test K pairwise(d, X, dims=2)
19+
20+
x = randn(rng, 10)
21+
X = reshape(x, :, 1)
22+
y = randn(rng, 11)
23+
Y = reshape(y, :, 1)
24+
K = zeros(10, 11)
25+
@test KernelFunctions.pairwise(d, x, y) pairwise(d, X, Y; dims=1)
26+
@test KernelFunctions.pairwise(d, x) pairwise(d, X; dims=1)
27+
KernelFunctions.pairwise!(K, d, x, y)
28+
@test K pairwise(d, X, Y, dims=1)
29+
end

test/generic.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,4 @@
33
@test length(k) == 1
44
@test iterate(k) == (k,nothing)
55
@test iterate(k,1) == nothing
6-
7-
rng = MersenneTwister(123456)
8-
x = randn(rng, 10)
9-
X = reshape(x, :, 1)
10-
y = randn(rng, 11)
11-
Y = reshape(y, :, 1)
12-
@test pairwise(SqEuclidean(), x, y) pairwise(SqEuclidean(), X, Y; dims=1)
13-
@test pairwise(SqEuclidean(), x) pairwise(SqEuclidean(), X; dims=1)
146
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ using KernelFunctions: metric, kappa, ColVecs, RowVecs
4848
include("utils_AD.jl")
4949

5050
@testset "distances" begin
51+
include(joinpath("distances", "pairwise.jl"))
5152
include(joinpath("distances", "dotproduct.jl"))
5253
include(joinpath("distances", "delta.jl"))
5354
include(joinpath("distances", "sinus.jl"))

test/utils.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222

2323
Y = randn(rng, D, N + 1)
2424
DY = ColVecs(Y)
25-
@test pairwise(SqEuclidean(), DX) pairwise(SqEuclidean(), X; dims=2)
26-
@test pairwise(SqEuclidean(), DX, DY) pairwise(SqEuclidean(), X, Y; dims=2)
25+
@test KernelFunctions.pairwise(SqEuclidean(), DX) pairwise(SqEuclidean(), X; dims=2)
26+
@test KernelFunctions.pairwise(SqEuclidean(), DX, DY) pairwise(SqEuclidean(), X, Y; dims=2)
2727
K = zeros(N, N)
28-
pairwise!(K, SqEuclidean(), DX)
28+
KernelFunctions.pairwise!(K, SqEuclidean(), DX)
2929
@test K pairwise(SqEuclidean(), X; dims=2)
3030
K = zeros(N, N + 1)
31-
pairwise!(K, SqEuclidean(), DX, DY)
31+
KernelFunctions.pairwise!(K, SqEuclidean(), DX, DY)
3232
@test K pairwise(SqEuclidean(), X, Y; dims=2)
3333

3434
let
@@ -56,13 +56,13 @@
5656

5757
Y = randn(rng, D + 1, N)
5858
DY = RowVecs(Y)
59-
@test pairwise(SqEuclidean(), DX) pairwise(SqEuclidean(), X; dims=1)
60-
@test pairwise(SqEuclidean(), DX, DY) pairwise(SqEuclidean(), X, Y; dims=1)
59+
@test KernelFunctions.pairwise(SqEuclidean(), DX) pairwise(SqEuclidean(), X; dims=1)
60+
@test KernelFunctions.pairwise(SqEuclidean(), DX, DY) pairwise(SqEuclidean(), X, Y; dims=1)
6161
K = zeros(D, D)
62-
pairwise!(K, SqEuclidean(), DX)
62+
KernelFunctions.pairwise!(K, SqEuclidean(), DX)
6363
@test K pairwise(SqEuclidean(), X; dims=1)
6464
K = zeros(D, D + 1)
65-
pairwise!(K, SqEuclidean(), DX, DY)
65+
KernelFunctions.pairwise!(K, SqEuclidean(), DX, DY)
6666
@test K pairwise(SqEuclidean(), X, Y; dims=1)
6767

6868
let

0 commit comments

Comments
 (0)