Skip to content

Commit 3264a92

Browse files
authored
Fix 346 (#347)
* Add methods for pairwise * Bump patch * Bump patch
1 parent d381a68 commit 3264a92

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.10"
3+
version = "0.10.11"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,10 @@ function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs, y::RowVecs)
165165
return Distances.pairwise!(out, d, x.X, y.X; dims=1)
166166
end
167167

168+
# Resolve ambiguity error for ColVecs vs RowVecs. #346
169+
pairwise(d::PreMetric, x::ColVecs, y::RowVecs) = pairwise(d, x, ColVecs(permutedims(y.X)))
170+
pairwise(d::PreMetric, x::RowVecs, y::ColVecs) = pairwise(d, ColVecs(permutedims(x.X)), y)
171+
168172
dim(x) = 0 # This is the passes-by-default choice. For a proper check, implement `KernelFunctions.dim` for your datatype.
169173
dim(x::AbstractVector) = dim(first(x))
170174
dim(x::AbstractVector{<:AbstractVector{<:Real}}) = length(first(x))

test/utils.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,19 @@
9393
@test back(ones(size(X)))[1].X == ones(size(X))
9494
end
9595
end
96+
@testset "ColVecs + RowVecs" begin
97+
x_colvecs = ColVecs(randn(3, 5))
98+
x_rowvecs = RowVecs(randn(7, 3))
99+
100+
@test isapprox(
101+
pairwise(SqEuclidean(), x_colvecs, x_rowvecs),
102+
pairwise(SqEuclidean(), collect(x_colvecs), collect(x_rowvecs)),
103+
)
104+
@test isapprox(
105+
pairwise(SqEuclidean(), x_rowvecs, x_colvecs),
106+
pairwise(SqEuclidean(), collect(x_rowvecs), collect(x_colvecs)),
107+
)
108+
end
96109
@testset "input checks" begin
97110
D = 3
98111
D⁻ = 2

0 commit comments

Comments
 (0)