Skip to content

Commit 91f51b5

Browse files
authored
Add efficient SparseVector method for some metrics (#235)
1 parent 265a363 commit 91f51b5

File tree

5 files changed

+135
-38
lines changed

5 files changed

+135
-38
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "Distances"
22
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
3-
version = "0.10.6"
3+
version = "0.10.7"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
78
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
89
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
910

src/Distances.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module Distances
22

33
using LinearAlgebra
44
using Statistics
5+
using SparseArrays: SparseVectorUnion, nonzeroinds, nonzeros, nnz
56
import StatsAPI: pairwise, pairwise!
67

78
export

src/bhattacharyya.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,32 @@ end
5050
return sqab, asum, bsum
5151
end
5252

53+
# Sparse specialisation of `_bhattacharyya_coeff`.
#
# Returns the tuple `(sqab, asum, bsum)`:
#   * `sqab` – sum of `sqrt(a[k] * b[k])` over the positions `k` stored in both vectors,
#   * `asum` – sum of the stored values of `a`,
#   * `bsum` – sum of the stored values of `b`.
#
# Positions stored in only one of the vectors are skipped when accumulating `sqab`,
# because the product with the implicit zero of the other vector is zero.
# The two-pointer walk relies on `nonzeroinds` being in increasing order
# (the usual SparseArrays storage invariant).
# No length comparison is done here.
@inline function _bhattacharyya_coeff(a::SparseVectorUnion{<:Number}, b::SparseVectorUnion{<:Number})
    idx_a, val_a = nonzeroinds(a), nonzeros(a)
    idx_b, val_b = nonzeroinds(b), nonzeros(b)
    na, nb = nnz(a), nnz(b)

    # Accumulator typed after the result of `sqrt` on a product of elements,
    # so it never changes type inside the loop.
    T = typeof(sqrt(oneunit(eltype(a)) * oneunit(eltype(b))))
    acc = zero(T)

    i, j = 1, 1
    @inbounds while i <= na && j <= nb
        pa, pb = idx_a[i], idx_b[j]
        if pa < pb
            i += 1
        elseif pb < pa
            j += 1
        else
            # Same position stored in both vectors: the only case that contributes.
            acc += sqrt(val_a[i] * val_b[j])
            i += 1
            j += 1
        end
    end

    # Sum the stored values directly: an efficient `sum` method for
    # SparseVectorView is missing.
    return acc, sum(val_a), sum(val_b)
end
78+
5379
# Faster pair- and column-wise versions TBD...
5480

5581

src/metrics.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,55 @@ Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::AbstractArray, b
308308
end
309309
end
310310

311+
# Apply `eval_op` to an element `ai` of the first input paired with a zero of
# `b`'s element type (used where `b` has no stored entry).
function eval_op_a(d, ai, b)
    return eval_op(d, ai, zero(eltype(b)))
end
312+
# Apply `eval_op` to a zero of `a`'s element type paired with an element `bi`
# of the second input (used where `a` has no stored entry).
function eval_op_b(d, bi, a)
    return eval_op(d, zero(eltype(a)), bi)
end
313+
314+
# Sparse/sparse evaluation of a parameter-free `UnionMetrics` distance.
#
# Only positions stored in at least one of the two vectors are visited. This is
# valid under the assumption that
#     eval_reduce(d, s, eval_op(d, zero(eltype(a)), zero(eltype(b)))) == s
# i.e. a pair of zeros never changes the running reduction, so such terms can
# be ignored.
#
# Throws `DimensionMismatch` when the lengths differ (the check sits under
# `@boundscheck`). Returns `zero(result_type(d, a, b))` for empty inputs.
# The merge relies on `nonzeroinds` being in increasing order (the usual
# SparseArrays storage invariant).
Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::SparseVectorUnion, b::SparseVectorUnion, ::Nothing)
    @boundscheck if length(a) != length(b)
        throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
    end
    iszero(length(a)) && return zero(result_type(d, a, b))

    idx_a, val_a = nonzeroinds(a), nonzeros(a)
    idx_b, val_b = nonzeroinds(b), nonzeros(b)
    na, nb = nnz(a), nnz(b)

    i, j = 1, 1
    acc = eval_start(d, a, b)

    # Merge phase: walk both index lists in lockstep.
    @inbounds while i <= na && j <= nb
        pa, pb = idx_a[i], idx_b[j]
        if pa < pb
            # Position stored only in `a`.
            term = eval_op_a(d, val_a[i], b)
            i += 1
        elseif pb < pa
            # Position stored only in `b`.
            term = eval_op_b(d, val_b[j], a)
            j += 1
        else
            # Position stored in both vectors.
            term = eval_op(d, val_a[i], val_b[j])
            i += 1
            j += 1
        end
        acc = eval_reduce(d, acc, term)
    end

    # Tail phases: at most one of these ranges is non-empty.
    @inbounds for k in i:na
        acc = eval_reduce(d, acc, eval_op_a(d, val_a[k], b))
    end
    @inbounds for k in j:nb
        acc = eval_reduce(d, acc, eval_op_b(d, val_b[k], a))
    end

    return eval_end(d, acc)
end
358+
359+
311360
_evaluate(dist::UnionMetrics, a::Number, b::Number, ::Nothing) = eval_end(dist, eval_op(dist, a, b))
312361
function _evaluate(dist::UnionMetrics, a::Number, b::Number, p)
313362
length(p) != 1 && throw(DimensionMismatch("inputs are scalars but parameters have length $(length(p))."))

test/test_dists.jl

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Unit tests for Distances
22

3+
using SparseArrays: sparsevec, sprand
4+
35
struct FooDist <: PreMetric end # Julia 1.0 Compat: struct definition must be put in global scope
46

57
@testset "result_type" begin
@@ -217,7 +219,7 @@ end
217219
for (_x, _y) in (([4.0, 5.0, 6.0, 7.0], [3.0, 9.0, 8.0, 1.0]),
218220
([4.0, 5.0, 6.0, 7.0], [3. 8.; 9. 1.0]))
219221
x, y = T.(_x), T.(_y)
220-
for (x, y) in ((x, y),
222+
for (x, y) in ((x, y), (sparsevec(x), sparsevec(y)),
221223
(convert(Array{Union{Missing, T}}, x), convert(Array{Union{Missing, T}}, y)),
222224
((Iterators.take(x, 4), Iterators.take(y, 4))), # iterator
223225
(((x[i] for i in 1:length(x)), (y[i] for i in 1:length(y)))), # generator
@@ -331,7 +333,8 @@ end # testset
331333
end #testset
332334

333335
@testset "empty vector" begin
334-
for T in (Float64, F64), (a, b) in ((T[], T[]), (Iterators.take(T[], 0), Iterators.take(T[], 0)))
336+
for T in (Float64, F64), (a, b) in ((T[], T[]), (Iterators.take(T[], 0), Iterators.take(T[], 0)),
337+
(sprand(T, 0, .1), sprand(T, 0, .1)))
335338
@test sqeuclidean(a, b) == 0.0
336339
@test isa(sqeuclidean(a, b), T)
337340
@test euclidean(a, b) == 0.0
@@ -391,6 +394,10 @@ end # testset
391394
@test_throws DimensionMismatch colwise!(mat23, Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), mat23, mat22)
392395
@test_throws DimensionMismatch Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x)([1, 2, 3], [1, 2])
393396
@test_throws DimensionMismatch Bregman(x -> sqeuclidean(x, zero(x)), x -> [1, 2])([1, 2, 3], [1, 2, 3])
397+
sv1 = sprand(10, .2)
398+
sv2 = sprand(20, .2)
399+
@test_throws DimensionMismatch euclidean(sv1, sv2)
400+
@test_throws DimensionMismatch bhattacharyya(sv1, sv2)
394401
end # testset
395402

396403
@testset "Different input types" begin
@@ -504,41 +511,43 @@ end
504511

505512
@testset "bhattacharyya / hellinger" begin
506513
for T in (Int, Float64, F64)
507-
x, y = T.([4, 5, 6, 7]), T.([3, 9, 8, 1])
508-
a = T.([1, 2, 1, 3, 2, 1])
509-
b = T.([1, 3, 0, 2, 2, 0])
510-
p = T == Int ? rand(0:10, 12) : rand(T, 12)
511-
p[p .< median(p)] .= 0
512-
q = T == Int ? rand(0:10, 12) : rand(T, 12)
513-
514-
# Bhattacharyya and Hellinger distances are defined for discrete
515-
# probability distributions so to calculate the expected values
516-
# we need to normalize vectors.
517-
px = x ./ sum(x)
518-
py = y ./ sum(y)
519-
expected_bc_x_y = sum(sqrt.(px .* py))
520-
for (x, y) in ((x, y), (Iterators.take(x, 12), Iterators.take(y, 12)))
521-
@test Distances.bhattacharyya_coeff(x, y) ≈ expected_bc_x_y
522-
@test bhattacharyya(x, y) ≈ (-log(expected_bc_x_y))
523-
@test hellinger(x, y) ≈ sqrt(1 - expected_bc_x_y)
524-
end
514+
_x, _y = T.([4, 5, 6, 7]), T.([3, 9, 8, 1])
515+
_a = T.([1, 2, 1, 3, 2, 1])
516+
_b = T.([1, 3, 0, 2, 2, 0])
517+
_p = T == Int ? rand(0:10, 12) : rand(T, 12)
518+
_p[_p .< median(_p)] .= 0
519+
_q = T == Int ? rand(0:10, 12) : rand(T, 12)
520+
521+
for (x, y, a, b, p, q) in ((_x, _y, _a, _b, _p, _q), sparsevec.((_x, _y, _a, _b, _p, _q)))
522+
# Bhattacharyya and Hellinger distances are defined for discrete
523+
# probability distributions so to calculate the expected values
524+
# we need to normalize vectors.
525+
px = x ./ sum(x)
526+
py = y ./ sum(y)
527+
expected_bc_x_y = sum(sqrt.(px .* py))
528+
for (x, y) in ((x, y), (Iterators.take(x, 12), Iterators.take(y, 12)))
529+
@test Distances.bhattacharyya_coeff(x, y) ≈ expected_bc_x_y
530+
@test bhattacharyya(x, y) ≈ (-log(expected_bc_x_y))
531+
@test hellinger(x, y) ≈ sqrt(1 - expected_bc_x_y)
532+
end
525533

526-
pa = a ./ sum(a)
527-
pb = b ./ sum(b)
528-
expected_bc_a_b = sum(sqrt.(pa .* pb))
529-
@test Distances.bhattacharyya_coeff(a, b) ≈ expected_bc_a_b
530-
@test bhattacharyya(a, b) ≈ (-log(expected_bc_a_b))
531-
@test hellinger(a, b) ≈ sqrt(1 - expected_bc_a_b)
532-
533-
pp = p ./ sum(p)
534-
pq = q ./ sum(q)
535-
expected_bc_p_q = sum(sqrt.(pp .* pq))
536-
@test Distances.bhattacharyya_coeff(p, q) ≈ expected_bc_p_q
537-
@test bhattacharyya(p, q) ≈ (-log(expected_bc_p_q))
538-
@test hellinger(p, q) ≈ sqrt(1 - expected_bc_p_q)
539-
540-
# Ensure it is semimetric
541-
@test bhattacharyya(x, y) ≈ bhattacharyya(y, x)
534+
pa = a ./ sum(a)
535+
pb = b ./ sum(b)
536+
expected_bc_a_b = sum(sqrt.(pa .* pb))
537+
@test Distances.bhattacharyya_coeff(a, b) ≈ expected_bc_a_b
538+
@test bhattacharyya(a, b) ≈ (-log(expected_bc_a_b))
539+
@test hellinger(a, b) ≈ sqrt(1 - expected_bc_a_b)
540+
541+
pp = p ./ sum(p)
542+
pq = q ./ sum(q)
543+
expected_bc_p_q = sum(sqrt.(pp .* pq))
544+
@test Distances.bhattacharyya_coeff(p, q) ≈ expected_bc_p_q
545+
@test bhattacharyya(p, q) ≈ (-log(expected_bc_p_q))
546+
@test hellinger(p, q) ≈ sqrt(1 - expected_bc_p_q)
547+
548+
# Ensure it is semimetric
549+
@test bhattacharyya(x, y) ≈ bhattacharyya(y, x)
550+
end
542551
end
543552
end #testset
544553

@@ -769,7 +778,7 @@ end
769778

770779
X = rand(ComplexF64, m, nx)
771780
Y = rand(ComplexF64, m, ny)
772-
781+
773782
test_pairwise(SqEuclidean(), X, Y, Float64)
774783
test_pairwise(Euclidean(), X, Y, Float64)
775784

@@ -946,6 +955,17 @@ end
946955
@test pairwise(PeriodicEuclidean(p), X, Y, dims=2)[1,2] == 0m
947956
end
948957

958+
# Checks that each listed distance gives (approximately) the same value on a
# pair of random sparse vectors as on their dense `Vector` copies. The two
# operands are generated with sharply different densities (0.1 vs 0.8, in both
# orders) so that their numbers of stored entries differ.
@testset "SparseVector, nnz(a) != nnz(b)" begin
    for (n, densa, densb) in ((100, .1, .8), (200, .8, .1))
        a = sprand(n, densa)
        b = sprand(n, densb)
        for d in (bhattacharyya, euclidean, sqeuclidean, jaccard, cityblock, totalvariation,
                  chebyshev, braycurtis, hamming)
            # Approximate comparison: the sparse and dense paths may accumulate
            # floating-point terms in a different order.
            @test d(a, b) ≈ d(Vector(a), Vector(b))
        end
    end
end
968+
949969
#=
950970
@testset "zero allocation colwise!" begin
951971
d = Euclidean()

0 commit comments

Comments
 (0)