Skip to content

Commit 9d09e91

Browse files
timholy authored and KristofferC committed
Add a precision threshold to Euclidean and SqEuclidean (#63)
If a matrix contains duplicated columns, often the distance between identical points (which should be 0) is of order 1e-8 due to the fact that sqrt(roundofferror) ~ 1e-8. This changes the behavior of Euclidean to recalculate the distance by direct subtraction when the points are close compared to their magnitudes.
1 parent c6b529d commit 9d09e91

File tree

3 files changed

+165
-23
lines changed

3 files changed

+165
-23
lines changed

README.md

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,32 @@ Each distance corresponds to a distance type. The type name and the correspondin
154154

155155
**Note:** The formulas above are using *Julia*'s functions. These formulas are mainly for conveying the math concepts in a concise way. The actual implementation may use a faster way.
156156

157+
### Precision for Euclidean and SqEuclidean
158+
159+
For efficiency (see the benchmarks below), `Euclidean` and
160+
`SqEuclidean` make use of BLAS3 matrix-matrix multiplication to
161+
calculate distances. This corresponds to the following expansion:
162+
163+
```julia
164+
(x-y)^2 == x^2 - 2xy + y^2
165+
```
166+
167+
However, equality is not precise in the presence of roundoff error,
168+
and particularly when `x` and `y` are nearby points this may not be
169+
accurate. Consequently, `Euclidean` and `SqEuclidean` allow you to
170+
supply a relative tolerance to force recalculation:
171+
172+
```julia
173+
julia> x = reshape([0.1, 0.3, -0.1], 3, 1);
174+
175+
julia> pairwise(Euclidean(), x, x)
176+
1×1 Array{Float64,2}:
177+
7.45058e-9
178+
179+
julia> pairwise(Euclidean(1e-12), x, x)
180+
1×1 Array{Float64,2}:
181+
0.0
182+
```
157183

158184
## Benchmarks
159185

@@ -215,5 +241,3 @@ The table below compares the performance (measured in terms of average elapsed t
215241
| Mahalanobis | 0.373796 | 0.002359 | **158.4337** |
216242

217243
For distances of which a major part of the computation is a quadratic form (e.g. *Euclidean*, *CosineDist*, *Mahalanobis*), the performance can be drastically improved by restructuring the computation and delegating the core part to ``GEMM`` in *BLAS*. The use of this strategy can easily lead to 100x performance gain over simple loops (see the highlighted part of the table above).
218-
219-

src/metrics.jl

Lines changed: 118 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66
#
77
###########################################################
88

9-
type Euclidean <: Metric end
10-
type SqEuclidean <: SemiMetric end
9+
immutable Euclidean <: Metric
10+
thresh::Float64
11+
end
12+
immutable SqEuclidean <: SemiMetric
13+
thresh::Float64
14+
end
1115
type Chebyshev <: Metric end
1216
type Cityblock <: Metric end
1317
type Jaccard <: Metric end
@@ -53,6 +57,44 @@ type SpanNormDist <: SemiMetric end
5357

5458
typealias UnionMetrics Union{Euclidean, SqEuclidean, Chebyshev, Cityblock, Minkowski, Hamming, Jaccard, RogersTanimoto, CosineDist, CorrDist, ChiSqDist, KLDivergence, RenyiDivergence, JSDivergence, SpanNormDist}
5559

60+
"""
61+
Euclidean([thresh])
62+
63+
Create a euclidean metric.
64+
65+
When computing distances among large numbers of points, it can be much
66+
more efficient to exploit the formula
67+
68+
(x-y)^2 = x^2 - 2xy + y^2
69+
70+
However, this can introduce roundoff error. `thresh` (which defaults
71+
to 0) specifies the relative square-distance tolerance on `2xy`
72+
compared to `x^2 + y^2` to force recalculation of the distance using
73+
the more precise direct (elementwise-subtraction) formula.
74+
75+
# Example:
76+
```julia
77+
julia> x = reshape([0.1, 0.3, -0.1], 3, 1);
78+
79+
julia> pairwise(Euclidean(), x, x)
80+
1×1 Array{Float64,2}:
81+
7.45058e-9
82+
83+
julia> pairwise(Euclidean(1e-12), x, x)
84+
1×1 Array{Float64,2}:
85+
0.0
86+
```
87+
"""
88+
Euclidean() = Euclidean(0)
89+
90+
"""
91+
SqEuclidean([thresh])
92+
93+
Create a squared-euclidean semi-metric. For the meaning of `thresh`,
94+
see [`Euclidean`](@ref).
95+
"""
96+
SqEuclidean() = SqEuclidean(0)
97+
5698
###########################################################
5799
#
58100
# Define Evaluate
@@ -289,6 +331,7 @@ end
289331
end
290332
rogerstanimoto{T <: Bool}(a::AbstractArray{T}, b::AbstractArray{T}) = evaluate(RogersTanimoto(), a, b)
291333

334+
292335
###########################################################
293336
#
294337
# Special method
@@ -300,28 +343,65 @@ function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix, b::A
300343
At_mul_B!(r, a, b)
301344
sa2 = sumabs2(a, 1)
302345
sb2 = sumabs2(b, 1)
303-
pdist!(r, sa2, sb2)
304-
end
305-
function pdist!(r, sa2, sb2)
306-
for j = 1 : size(r,2)
307-
sb = sb2[j]
308-
@simd for i = 1 : size(r,1)
309-
@inbounds r[i,j] = sa2[i] + sb - 2 * r[i,j]
346+
threshT = convert(eltype(r), dist.thresh)
347+
if threshT <= 0
348+
# If there's no chance of triggering the threshold, we can use @simd
349+
for j = 1 : size(r,2)
350+
sb = sb2[j]
351+
@simd for i = 1 : size(r,1)
352+
@inbounds r[i,j] = sa2[i] + sb - 2 * r[i,j]
353+
end
354+
end
355+
else
356+
for j = 1 : size(r,2)
357+
sb = sb2[j]
358+
for i = 1 : size(r,1)
359+
@inbounds selfterms = sa2[i] + sb
360+
@inbounds v = selfterms - 2*r[i,j]
361+
if v < threshT*selfterms
362+
# The distance is likely to be inaccurate, recalculate at higher prec.
363+
# This reflects the following:
364+
# ((x+ϵ) - y)^2 ≈ x^2 - 2xy + y^2 + O(ϵ) when |x-y| >> ϵ
365+
# ((x+ϵ) - y)^2 ≈ O(ϵ^2) otherwise
366+
v = zero(v)
367+
for k = 1:size(a,1)
368+
@inbounds v += (a[k,i]-b[k,j])^2
369+
end
370+
end
371+
@inbounds r[i,j] = v
372+
end
310373
end
311374
end
312375
r
313376
end
377+
314378
function pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix)
315379
m, n = get_pairwise_dims(r, a)
316380
At_mul_B!(r, a, a)
317381
sa2 = sumsq_percol(a)
382+
threshT = convert(eltype(r), dist.thresh)
318383
@inbounds for j = 1 : n
319384
for i = 1 : j-1
320385
r[i,j] = r[j,i]
321386
end
322387
r[j,j] = 0
323-
for i = j+1 : n
324-
r[i,j] = sa2[i] + sa2[j] - 2 * r[i,j]
388+
sa2j = sa2[j]
389+
if threshT <= 0
390+
@simd for i = j+1 : n
391+
r[i,j] = sa2[i] + sa2j - 2 * r[i,j]
392+
end
393+
else
394+
for i = j+1 : n
395+
selfterms = sa2[i] + sa2j
396+
v = selfterms - 2*r[i,j]
397+
if v < threshT*selfterms
398+
v = zero(v)
399+
for k = 1:size(a,1)
400+
v += (a[k,i]-a[k,j])^2
401+
end
402+
end
403+
r[i,j] = v
404+
end
325405
end
326406
end
327407
r
@@ -333,10 +413,23 @@ function pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix, b::Abs
333413
At_mul_B!(r, a, b)
334414
sa2 = sumsq_percol(a)
335415
sb2 = sumsq_percol(b)
416+
threshT = convert(eltype(r), dist.thresh)
336417
@inbounds for j = 1 : nb
418+
sb = sb2[j]
337419
for i = 1 : na
338-
v = sa2[i] + sb2[j] - 2 * r[i,j]
339-
r[i,j] = isnan(v) ? NaN : sqrt(max(v, 0.))
420+
selfterms = sa2[i] + sb
421+
v = selfterms - 2*r[i,j]
422+
if v < threshT*selfterms
423+
# The distance is likely to be inaccurate, recalculate directly
424+
# This reflects the following:
425+
# while sqrt(x+ϵ) ≈ sqrt(x) + O(ϵ/sqrt(x)) when |x| >> ϵ,
426+
# sqrt(x+ϵ) ≈ O(sqrt(ϵ)) otherwise.
427+
v = zero(v)
428+
for k = 1:m
429+
v += (a[k,i]-b[k,j])^2
430+
end
431+
end
432+
r[i,j] = sqrt(v)
340433
end
341434
end
342435
r
@@ -346,14 +439,23 @@ function pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix)
346439
m, n = get_pairwise_dims(r, a)
347440
At_mul_B!(r, a, a)
348441
sa2 = sumsq_percol(a)
442+
threshT = convert(eltype(r), dist.thresh)
349443
@inbounds for j = 1 : n
350444
for i = 1 : j-1
351445
r[i,j] = r[j,i]
352446
end
353-
@inbounds r[j,j] = 0
447+
r[j,j] = 0
448+
sa2j = sa2[j]
354449
for i = j+1 : n
355-
v = sa2[i] + sa2[j] - 2 * r[i,j]
356-
r[i,j] = isnan(v) ? NaN : sqrt(max(v, 0.))
450+
selfterms = sa2[i] + sa2j
451+
v = selfterms - 2*r[i,j]
452+
if v < threshT*selfterms
453+
v = zero(v)
454+
for k = 1:m
455+
v += (a[k,i]-a[k,j])^2
456+
end
457+
end
458+
r[i,j] = sqrt(v)
357459
end
358460
end
359461
r

test/test_dists.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ w = rand(size(a))
103103
p = r = rand(12)
104104
p[p .< 0.3] = 0.0
105105
scale = sum(p) / sum(r)
106-
r /= sum(r)
106+
r /= sum(r)
107107
p /= sum(p)
108108
q = rand(12)
109109
q /= sum(q)
@@ -121,14 +121,14 @@ end
121121
@test renyi_divergence(p, p, rand()) ≈ 0
122122
@test renyi_divergence(p, p, 1.0 + rand()) ≈ 0
123123
@test renyi_divergence(p, p, Inf) ≈ 0
124-
@test renyi_divergence(p, r, 0) ≈ -log(scale)
125-
@test renyi_divergence(p, r, 1) ≈ -log(scale)
126-
@test renyi_divergence(p, r, rand()) ≈ -log(scale)
124+
@test renyi_divergence(p, r, 0) ≈ -log(scale)
125+
@test renyi_divergence(p, r, 1) ≈ -log(scale)
126+
@test renyi_divergence(p, r, rand()) ≈ -log(scale)
127127
@test renyi_divergence(p, r, Inf) ≈ -log(scale)
128128
@test isinf(renyi_divergence([0.0, 0.5, 0.5], [0.0, 1.0, 0.0], Inf))
129129
@test renyi_divergence([0.0, 1.0, 0.0], [0.0, 0.5, 0.5], Inf) ≈ log(2.0)
130130
@test renyi_divergence(p, q, 1) ≈ kl_divergence(p, q)
131-
131+
132132
pm = (p + q) / 2
133133
jsv = kl_divergence(p, pm) / 2 + kl_divergence(q, pm) / 2
134134
@test js_divergence(p, p) ≈ 0.0
@@ -385,3 +385,19 @@ Q = Q * Q' # make sure Q is positive-definite
385385
@test_pairwise Mahalanobis(Q) X Y
386386

387387
end #testset
388+
389+
@testset "Euclidean precision" begin
390+
X = [0.1 0.2; 0.3 0.4; -0.1 -0.1]
391+
pd = pairwise(Euclidean(1e-12), X, X)
392+
@test pd[1,1] == 0
393+
@test pd[2,2] == 0
394+
pd = pairwise(Euclidean(1e-12), X)
395+
@test pd[1,1] == 0
396+
@test pd[2,2] == 0
397+
pd = pairwise(SqEuclidean(1e-12), X, X)
398+
@test pd[1,1] == 0
399+
@test pd[2,2] == 0
400+
pd = pairwise(SqEuclidean(1e-12), X)
401+
@test pd[1,1] == 0
402+
@test pd[2,2] == 0
403+
end

0 commit comments

Comments
 (0)