Skip to content

Commit c63dc14

Browse files
authored
Make pairwise work with unitful data (#230)
1 parent cb6dcb2 commit c63dc14

File tree

4 files changed

+56
-23
lines changed

4 files changed

+56
-23
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Distances"
22
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
3-
version = "0.10.4"
3+
version = "0.10.5"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/common.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,7 @@ end
8686

8787
function norm_percol(a::AbstractMatrix{T}) where {T}
8888
n = size(a, 2)
89-
T = typeof(sqrt(oneunit(T)))
90-
r = Vector{√T}(undef, n)
89+
r = Vector{float(T)}(undef, n)
9190
@simd for j in 1:n
9291
aj = view(a, :, j)
9392
r[j] = sqrt(dot(aj, aj))

src/metrics.jl

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -619,22 +619,22 @@ function _pairwise!(r::AbstractMatrix, dist::Union{SqEuclidean,Euclidean},
619619
R = inplace ? mul!(r, a', b) : a'b
620620
sa2 = sum(abs2, a, dims=1)
621621
sb2 = sum(abs2, b, dims=1)
622-
threshT = convert(eltype(r), dist.thresh)
623-
@inbounds if threshT <= 0
622+
z² = zero(real(eltype(R)))
623+
@inbounds if dist.thresh <= 0
624624
# If there's no chance of triggering the threshold, we can use @simd
625625
for j = 1:nb
626626
sb = sb2[j]
627627
@simd for i = 1:na
628-
r[i, j] = eval_end(dist, (max(sa2[i] + sb - 2real(R[i, j]), 0)))
628+
r[i, j] = eval_end(dist, (max(sa2[i] + sb - 2real(R[i, j]), z²)))
629629
end
630630
end
631631
else
632632
for j = 1:nb
633633
sb = sb2[j]
634634
for i = 1:na
635635
selfterms = sa2[i] + sb
636-
v = max(selfterms - 2real(R[i, j]), 0)
637-
if v < threshT * selfterms
636+
v = max(selfterms - 2real(R[i, j]), z²)
637+
if v < dist.thresh * selfterms
638638
# The distance is likely to be inaccurate, recalculate directly
639639
# This reflects the following:
640640
# while sqrt(x+ϵ) ≈ sqrt(x) + O(ϵ/sqrt(x)) when |x| >> ϵ,
@@ -658,22 +658,23 @@ function _pairwise!(r::AbstractMatrix, dist::Union{SqEuclidean,Euclidean}, a::Ab
658658
# the following checks if a'*b can be stored in r directly, it fails for complex eltypes
659659
R = inplace ? mul!(r, a', a) : a'a
660660
sa2 = sum(abs2, a, dims=1)
661-
threshT = convert(eltype(r), dist.thresh)
661+
safe = dist.thresh <= 0
662+
z² = zero(real(eltype(R)))
662663
@inbounds for j = 1:n
663664
for i = 1:(j - 1)
664665
r[i, j] = r[j, i]
665666
end
666-
r[j, j] = 0
667+
r[j, j] = zero(eltype(r))
667668
sa2j = sa2[j]
668-
if threshT <= 0
669+
if safe
669670
@simd for i = (j + 1):n
670-
r[i, j] = eval_end(dist, (max(sa2[i] + sa2j - 2real(R[i, j]), 0)))
671+
r[i, j] = eval_end(dist, (max(sa2[i] + sa2j - 2real(R[i, j]), z²)))
671672
end
672673
else
673674
for i = (j + 1):n
674675
selfterms = sa2[i] + sa2j
675-
v = max(selfterms - 2real(R[i, j]), 0)
676-
if v < threshT * selfterms
676+
v = max(selfterms - 2real(R[i, j]), z²)
677+
if v < dist.thresh * selfterms
677678
v = zero(v)
678679
for k = 1:m
679680
v += abs2(a[k, i] - a[k, j])
@@ -698,9 +699,10 @@ function _pairwise!(r::AbstractMatrix, dist::Union{WeightedSqEuclidean,WeightedE
698699
# the following checks if a'*b can be stored in r directly, it fails for complex eltypes
699700
inplace = promote_type(eltype(r), typeof(oneunit(eltype(a))'oneunit(eltype(b)))) === eltype(r)
700701
R = inplace ? mul!(r, a', w .* b) : a'*Diagonal(w)*b
702+
z² = zero(real(eltype(R)))
701703
for j = 1:nb
702704
@simd for i = 1:na
703-
@inbounds r[i, j] = eval_end(dist, max(sa2[i] + sb2[j] - 2real(R[i, j]), 0))
705+
@inbounds r[i, j] = eval_end(dist, max(sa2[i] + sb2[j] - 2real(R[i, j]), z²))
704706
end
705707
end
706708
r
@@ -715,14 +717,15 @@ function _pairwise!(r::AbstractMatrix, dist::Union{WeightedSqEuclidean,WeightedE
715717
# the following checks if a'*b can be stored in r directly, it fails for complex eltypes
716718
inplace = promote_type(eltype(r), typeof(oneunit(eltype(a))'oneunit(eltype(a)))) === eltype(r)
717719
R = inplace ? mul!(r, a', w .* a) : a'*Diagonal(w)*a
720+
z² = zero(real(eltype(R)))
718721

719722
@inbounds for j = 1:n
720723
for i = 1:(j - 1)
721724
r[i, j] = r[j, i]
722725
end
723-
r[j, j] = 0
726+
r[j, j] = zero(eltype(r))
724727
@simd for i = (j + 1):n
725-
r[i, j] = eval_end(dist, max(sa2[i] + sa2[j] - 2real(R[i, j]), 0))
728+
r[i, j] = eval_end(dist, max(sa2[i] + sa2[j] - 2real(R[i, j]), z²))
726729
end
727730
end
728731
r
@@ -734,28 +737,30 @@ function _pairwise!(r::AbstractMatrix, ::CosineDist,
734737
a::AbstractMatrix, b::AbstractMatrix)
735738
require_one_based_indexing(r, a, b)
736739
m, na, nb = get_pairwise_dims(r, a, b)
737-
mul!(r, a', b)
740+
inplace = promote_type(eltype(r), typeof(oneunit(eltype(a))'oneunit(eltype(b)))) === eltype(r)
741+
R = inplace ? mul!(r, a', b) : a'b
738742
ra = norm_percol(a)
739743
rb = norm_percol(b)
740744
for j = 1:nb
741745
@simd for i = 1:na
742-
@inbounds r[i, j] = max(1 - r[i, j] / (ra[i] * rb[j]), 0)
746+
@inbounds r[i, j] = max(1 - R[i, j] / (ra[i] * rb[j]), 0)
743747
end
744748
end
745749
r
746750
end
747751
function _pairwise!(r::AbstractMatrix, ::CosineDist, a::AbstractMatrix)
748752
require_one_based_indexing(r, a)
749753
m, n = get_pairwise_dims(r, a)
750-
mul!(r, a', a)
754+
inplace = promote_type(eltype(r), typeof(oneunit(eltype(a))'oneunit(eltype(a)))) === eltype(r)
755+
R = inplace ? mul!(r, a', a) : a'a
751756
ra = norm_percol(a)
752757
@inbounds for j = 1:n
753758
for i = 1:(j - 1)
754759
r[i, j] = r[j, i]
755760
end
756-
r[j, j] = 0
761+
r[j, j] = zero(eltype(r))
757762
@simd for i = j + 1:n
758-
r[i, j] = max(1 - r[i, j] / (ra[i] * ra[j]), 0)
763+
r[i, j] = max(1 - R[i, j] / (ra[i] * ra[j]), 0)
759764
end
760765
end
761766
r

test/test_dists.jl

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,8 @@ end
878878
@test bregman(G, ∇G, p, q) ≈ ISdist(p, q)
879879
end
880880

881-
@testset "Unitful vectors" begin
881+
@testset "Unitful data" begin
882+
using Distances, Unitful.DefaultSymbols, Test, LinearAlgebra
882883
x = [1m, 2m, 3m]; y = [2m, 3m, 4m]; w = [1, 1, 1]; p = [2m, 2m, 2m]
883884
@test @inferred sqeuclidean(x, y) == 3m^2
884885
@test @inferred euclidean(x, y) == sqrt(3)m
@@ -903,6 +904,34 @@ end
903904
@test @inferred wminkowski(x, y, w, 2) == euclidean(x, y)
904905
@test @inferred whamming(x, y, w) == hamming(x, y)
905906
@test @inferred peuclidean(x, y, p) == sqrt(3)m
907+
908+
X = [x y]; Y = [y x]
909+
# check specialized pairwise implementations
910+
@test pairwise(Euclidean(), X, dims=2)[1,1] == 0m
911+
@test pairwise(Euclidean(), X, dims=2)[1,2] == sqrt(3)m
912+
@test pairwise(SqEuclidean(), X, dims=2)[1,1] == 0m^2
913+
@test pairwise(SqEuclidean(), X, dims=2)[1,2] == 3m^2
914+
@test pairwise(WeightedEuclidean(w), X, dims=2)[1,1] == 0m
915+
@test pairwise(WeightedEuclidean(w), X, dims=2)[1,2] == sqrt(3)m
916+
@test pairwise(WeightedSqEuclidean(w), X, dims=2)[1,1] == 0m^2
917+
@test pairwise(WeightedSqEuclidean(w), X, dims=2)[1,2] == 3m^2
918+
@test pairwise(Euclidean(), X, Y, dims=2)[1,1] == sqrt(3)m
919+
@test pairwise(Euclidean(), X, Y, dims=2)[1,2] == 0m
920+
@test pairwise(SqEuclidean(), X, Y, dims=2)[1,1] == 3m^2
921+
@test pairwise(SqEuclidean(), X, Y, dims=2)[1,2] == 0m^2
922+
@test pairwise(WeightedEuclidean(w), X, Y, dims=2)[1,1] == sqrt(3)m
923+
@test pairwise(WeightedEuclidean(w), X, Y, dims=2)[1,2] == 0m
924+
@test pairwise(WeightedSqEuclidean(w), X, Y, dims=2)[1,1] == 3m^2
925+
@test pairwise(WeightedSqEuclidean(w), X, Y, dims=2)[1,2] == 0m^2
926+
@test pairwise(CosineDist(), X, dims=2)[1,1] == 0
927+
@test pairwise(CosineDist(), X, dims=2)[1,2] == 1 - dot(x, y) / (norm(x) * norm(y))
928+
@test pairwise(CorrDist(), X, dims=2)[1,1] == 0
929+
@test pairwise(CorrDist(), X, dims=2)[1,2] == cosine_dist(x .- mean(x), y .- mean(y))
930+
# check generic pairwise implementation for one metric
931+
@test pairwise(PeriodicEuclidean(p), X, dims=2)[1,1] == 0m
932+
@test pairwise(PeriodicEuclidean(p), X, dims=2)[1,2] == sqrt(3)m
933+
@test pairwise(PeriodicEuclidean(p), X, Y, dims=2)[1,1] == sqrt(3)m
934+
@test pairwise(PeriodicEuclidean(p), X, Y, dims=2)[1,2] == 0m
906935
end
907936

908937
#=

0 commit comments

Comments
 (0)