
Commit 11c2056

dkarrasch authored and KristofferC committed
add periodic distance, second attempt (#129)
* add PeriodicEuclidean and tests
* add new distance to README and benchmarks
* remove stupid "fallback"
* update benchmark tables in README
1 parent: bb54e3b

File tree: 7 files changed (+171, -66 lines changed)

README.md

Lines changed: 56 additions & 50 deletions
@@ -14,6 +14,7 @@ This package also provides optimized functions to compute column-wise and pairwi

 * Euclidean distance
 * Squared Euclidean distance
+* Periodic Euclidean distance
 * Cityblock distance
 * Total variation distance
 * Jaccard distance
@@ -38,12 +39,12 @@ This package also provides optimized functions to compute column-wise and pairwi
 * Root mean squared deviation
 * Normalized root mean squared deviation
 * Bray-Curtis dissimilarity
-* Bregman divergence
+* Bregman divergence

 For ``Euclidean distance``, ``Squared Euclidean distance``, ``Cityblock distance``, ``Minkowski distance``, and ``Hamming distance``, a weighted version is also provided.


-## Basic Use
+## Basic use

 The library supports three ways of computation: *computing the distance between two vectors*, *column-wise computation*, and *pairwise computation*.
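To make the quoted sentence concrete for the metric added in this commit, here is a minimal sketch of the three modes with `PeriodicEuclidean`. The array shapes and the column-as-observation convention are assumptions based on the rest of the README, not part of this diff:

```julia
using Distances

L = [0.5, Inf]                 # periods per dimension; Inf marks a non-periodic dimension
x, y = rand(2), rand(2)
X, Y = rand(2, 10), rand(2, 10)

evaluate(PeriodicEuclidean(L), x, y)   # distance between two vectors
colwise(PeriodicEuclidean(L), X, Y)    # distances between corresponding columns
pairwise(PeriodicEuclidean(L), X, Y)   # distances between all pairs of columns
```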

@@ -140,6 +141,7 @@ Each distance corresponds to a distance type. The type name and the correspondin
 | -------------------- | -------------------------- | --------------------|
 | Euclidean | `euclidean(x, y)` | `sqrt(sum((x - y) .^ 2))` |
 | SqEuclidean | `sqeuclidean(x, y)` | `sum((x - y).^2)` |
+| PeriodicEuclidean | `peuclidean(x, y, p)` | `sqrt(sum(min(mod(abs(x - y), p), p - mod(abs(x - y), p)).^2))` |
 | Cityblock | `cityblock(x, y)` | `sum(abs(x - y))` |
 | TotalVariation | `totalvariation(x, y)` | `sum(abs(x - y)) / 2` |
 | Chebyshev | `chebyshev(x, y)` | `max(abs(x - y))` |
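As a quick numerical check of the new `PeriodicEuclidean` row, the example below (the same case used in the docstring added to `src/metrics.jl` later in this diff) wraps a displacement of 0.75 on a dimension with period 0.5 back to 0.25:

```julia
julia> peuclidean([0.0, 0.0], [0.75, 0.0], [0.5, Inf])
0.25
```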
@@ -170,7 +172,7 @@ Each distance corresponds to a distance type. The type name and the correspondin
 | WeightedCityblock | `wcityblock(x, y, w)` | `sum(abs(x - y) .* w)` |
 | WeightedMinkowski | `wminkowski(x, y, w, p)` | `sum(abs(x - y).^p .* w) ^ (1/p)` |
 | WeightedHamming | `whamming(x, y, w)` | `sum((x .!= y) .* w)` |
-| Bregman | `bregman(F, ∇, x, y; inner = LinearAlgebra.dot)` | `F(x) - F(y) - inner(∇(y), x - y)` |
+| Bregman | `bregman(F, ∇, x, y; inner = LinearAlgebra.dot)` | `F(x) - F(y) - inner(∇(y), x - y)` |

 **Note:** The formulas above are using *Julia*'s functions. These formulas are mainly for conveying the math concepts in a concise way. The actual implementation may use a faster way. The arguments `x` and `y` are arrays of real numbers; `k` and `l` are arrays of distinct elements of any kind; a and b are arrays of Bools; and finally, `p` and `q` are arrays forming a discrete probability distribution and are therefore both expected to sum to one.

@@ -207,7 +209,7 @@ julia> pairwise(Euclidean(1e-12), x, x)
 The implementation has been carefully optimized based on benchmarks. The script in `benchmark/benchmarks.jl` defines a benchmark suite
 for a variety of distances, under column-wise and pairwise settings.

-Here are benchmarks obtained running Julia 0.6 on a computer with a quad-core Intel Core i5-2500K processor @ 3.3 GHz.
+Here are benchmarks obtained running Julia 1.0 on a computer with a dual-core Intel Core i5-2300K processor @ 2.3 GHz.
 The tables below can be replicated using the script in `benchmark/print_table.jl`.

 #### Column-wise benchmark
@@ -216,29 +218,31 @@ The table below compares the performance (measured in terms of average elapsed t

 | distance | loop | colwise | gain |
 |----------- | -------| ----------| -------|
-| SqEuclidean | 0.005460s | 0.001676s | 3.2582 |
-| Euclidean | 0.005513s | 0.001681s | 3.2792 |
-| Cityblock | 0.005409s | 0.001675s | 3.2292 |
-| Chebyshev | 0.008592s | 0.004575s | 1.8779 |
-| Minkowski | 0.056741s | 0.048808s | 1.1625 |
-| Hamming | 0.005320s | 0.001670s | 3.1847 |
-| CosineDist | 0.005663s | 0.001697s | 3.3378 |
-| CorrDist | 0.010000s | 0.013904s | 0.7192 |
-| ChiSqDist | 0.009626s | 0.004734s | 2.0333 |
-| KLDivergence | 0.046696s | 0.035091s | 1.3307 |
-| RenyiDivergence | 0.021123s | 0.012006s | 1.7594 |
-| RenyiDivergence | 0.080503s | 0.066987s | 1.2018 |
-| JSDivergence | 0.066404s | 0.059564s | 1.1148 |
-| BhattacharyyaDist | 0.013065s | 0.008807s | 1.4836 |
-| HellingerDist | 0.013013s | 0.008679s | 1.4993 |
-| WeightedSqEuclidean | 0.005534s | 0.001676s | 3.3028 |
-| WeightedEuclidean | 0.005601s | 0.001723s | 3.2513 |
-| WeightedCityblock | 0.005496s | 0.001675s | 3.2815 |
-| WeightedMinkowski | 0.057847s | 0.051389s | 1.1257 |
-| WeightedHamming | 0.005439s | 0.001673s | 3.2513 |
-| SqMahalanobis | 0.134717s | 0.019530s | 6.8980 |
-| Mahalanobis | 0.129455s | 0.020114s | 6.4361 |
-| BrayCurtis | 0.005666s | 0.001680s | 3.3736 |
+| SqEuclidean | 0.004432s | 0.001049s | 4.2270 |
+| Euclidean | 0.004537s | 0.001054s | 4.3031 |
+| PeriodicEuclidean | 0.012092s | 0.006714s | 1.8011 |
+| Cityblock | 0.004515s | 0.001060s | 4.2585 |
+| TotalVariation | 0.004496s | 0.001062s | 4.2337 |
+| Chebyshev | 0.009123s | 0.005034s | 1.8123 |
+| Minkowski | 0.047573s | 0.042508s | 1.1191 |
+| Hamming | 0.004355s | 0.001099s | 3.9638 |
+| CosineDist | 0.006432s | 0.002282s | 2.8185 |
+| CorrDist | 0.010273s | 0.012500s | 0.8219 |
+| ChiSqDist | 0.005291s | 0.001271s | 4.1635 |
+| KLDivergence | 0.031491s | 0.025643s | 1.2281 |
+| RenyiDivergence | 0.052420s | 0.048075s | 1.0904 |
+| RenyiDivergence | 0.017317s | 0.009023s | 1.9193 |
+| JSDivergence | 0.047905s | 0.044006s | 1.0886 |
+| BhattacharyyaDist | 0.007761s | 0.003796s | 2.0445 |
+| HellingerDist | 0.007636s | 0.003665s | 2.0836 |
+| WeightedSqEuclidean | 0.004550s | 0.001151s | 3.9541 |
+| WeightedEuclidean | 0.004687s | 0.001168s | 4.0125 |
+| WeightedCityblock | 0.004493s | 0.001157s | 3.8849 |
+| WeightedMinkowski | 0.049442s | 0.042145s | 1.1732 |
+| WeightedHamming | 0.004431s | 0.001153s | 3.8440 |
+| SqMahalanobis | 0.082493s | 0.019843s | 4.1574 |
+| Mahalanobis | 0.082180s | 0.019618s | 4.1891 |
+| BrayCurtis | 0.004464s | 0.001121s | 3.9809 |

 We can see that using ``colwise`` instead of a simple loop yields considerable gain (2x - 4x), especially when the internal computation of each distance is simple. Nonetheless, when the computation of a single distance is heavy enough (e.g. *KLDivergence*, *RenyiDivergence*), the gain is not as significant.

@@ -248,28 +252,30 @@ The table below compares the performance (measured in terms of average elapsed t

 | distance | loop | pairwise | gain |
 |----------- | -------| ----------| -------|
-| SqEuclidean | 0.015116s | 0.000192s | **78.7747** |
-| Euclidean | 0.015565s | 0.000390s | 39.8829 |
-| Cityblock | 0.015048s | 0.001400s | 10.7469 |
-| Chebyshev | 0.023325s | 0.010921s | 2.1358 |
-| Minkowski | 0.143427s | 0.121050s | 1.1849 |
-| Hamming | 0.015191s | 0.001334s | 11.3856 |
-| CosineDist | 0.016688s | 0.000393s | **42.5158** |
-| CorrDist | 0.029024s | 0.000435s | **66.7043** |
-| ChiSqDist | 0.026035s | 0.012194s | 2.1351 |
-| KLDivergence | 0.115800s | 0.086968s | 1.3315 |
-| RenyiDivergence | 0.055551s | 0.029628s | 1.8749 |
-| RenyiDivergence | 0.205270s | 0.163031s | 1.2591 |
-| JSDivergence | 0.165078s | 0.148902s | 1.1086 |
-| BhattacharyyaDist | 0.035493s | 0.022429s | 1.5824 |
-| HellingerDist | 0.035028s | 0.021867s | 1.6019 |
-| WeightedSqEuclidean | 0.016330s | 0.000276s | **59.2117** |
-| WeightedEuclidean | 0.016600s | 0.000508s | **32.6478** |
-| WeightedCityblock | 0.015604s | 0.001816s | 8.5913 |
-| WeightedMinkowski | 0.159052s | 0.128427s | 1.2385 |
-| WeightedHamming | 0.015212s | 0.001634s | 9.3110 |
-| SqMahalanobis | 0.607881s | 0.000365s | **1665.3228** |
-| Mahalanobis | 0.623032s | 0.000604s | **1031.9581** |
-| BrayCurtis | 0.015843s | 0.002273s | 6.9695 |
+| SqEuclidean | 0.012498s | 0.000170s | **73.6596** |
+| Euclidean | 0.012583s | 0.000257s | 48.9628 |
+| PeriodicEuclidean | 0.030935s | 0.017572s | 1.7605 |
+| Cityblock | 0.012416s | 0.000910s | 13.6464 |
+| TotalVariation | 0.012763s | 0.000959s | 13.3080 |
+| Chebyshev | 0.023800s | 0.012042s | 1.9763 |
+| Minkowski | 0.121388s | 0.107333s | 1.1310 |
+| Hamming | 0.012171s | 0.000689s | 17.6538 |
+| CosineDist | 0.017474s | 0.000214s | **81.6546** |
+| CorrDist | 0.028195s | 0.000259s | **108.7360** |
+| ChiSqDist | 0.014372s | 0.003129s | 4.5932 |
+| KLDivergence | 0.079669s | 0.063491s | 1.2548 |
+| RenyiDivergence | 0.134093s | 0.117737s | 1.1389 |
+| RenyiDivergence | 0.047658s | 0.024960s | 1.9094 |
+| JSDivergence | 0.121999s | 0.110984s | 1.0993 |
+| BhattacharyyaDist | 0.021788s | 0.009414s | 2.3145 |
+| HellingerDist | 0.020735s | 0.008784s | 2.3606 |
+| WeightedSqEuclidean | 0.012671s | 0.000186s | **68.0345** |
+| WeightedEuclidean | 0.012867s | 0.000276s | **46.6634** |
+| WeightedCityblock | 0.012803s | 0.001539s | 8.3200 |
+| WeightedMinkowski | 0.127386s | 0.107257s | 1.1877 |
+| WeightedHamming | 0.012240s | 0.001462s | 8.3747 |
+| SqMahalanobis | 0.214285s | 0.000330s | **650.0722** |
+| Mahalanobis | 0.197294s | 0.000420s | **470.2354** |
+| BrayCurtis | 0.012872s | 0.001489s | 8.6456 |

 For distances of which a major part of the computation is a quadratic form (e.g. *Euclidean*, *CosineDist*, *Mahalanobis*), the performance can be drastically improved by restructuring the computation and delegating the core part to ``GEMM`` in *BLAS*. The use of this strategy can easily lead to 100x performance gain over simple loops (see the highlighted part of the table above).
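The restructuring mentioned here is the usual expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, which turns all the cross terms of a pairwise computation into a single matrix product. The sketch below only illustrates the idea; it is not the package's actual implementation, and `sqeuclidean_pairwise_sketch` is a hypothetical name:

```julia
using LinearAlgebra

# Columns of X (d×n) and Y (d×m) are observations, as in `pairwise`.
function sqeuclidean_pairwise_sketch(X::AbstractMatrix, Y::AbstractMatrix)
    sx = sum(abs2, X; dims=1)            # 1×n squared column norms of X
    sy = sum(abs2, Y; dims=1)            # 1×m squared column norms of Y
    G  = X' * Y                          # n×m cross terms, computed by BLAS GEMM
    return max.(sx' .+ sy .- 2 .* G, 0)  # clamp tiny negative round-off to zero
end
```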

benchmark/benchmarks.jl

Lines changed: 2 additions & 1 deletion
@@ -7,6 +7,7 @@ function create_distances(w, Q)
     dists = [
         SqEuclidean(),
         Euclidean(),
+        PeriodicEuclidean(w),
         Cityblock(),
         TotalVariation(),
         Chebyshev(),
@@ -144,7 +145,7 @@ function add_pairwise_benchmarks!(SUITE)
             Tdist = typeof(dist)
             SUITE["pairwise"][Tdist] = BenchmarkGroup()
             SUITE["pairwise"][Tdist]["loop"] = @benchmarkable evaluate_pairwise($dist, $a, $b)
-            SUITE["pairwise"][Tdist]["specialized"] = @benchmarkable pairwise($dist, $a, $b)
+            SUITE["pairwise"][Tdist]["specialized"] = @benchmarkable pairwise($dist, $a, $b; dims=2)
         end
     end
 end
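For context, the "loop" entry benchmarks `evaluate_pairwise`, whose definition lives elsewhere in `benchmarks.jl` and is not shown in this diff; it is presumably a naive double loop along the lines of the hedged sketch below (the name `evaluate_pairwise_sketch` is hypothetical):

```julia
using Distances

# Naive baseline: one `evaluate` call per pair of columns, no BLAS restructuring.
function evaluate_pairwise_sketch(dist, a::AbstractMatrix, b::AbstractMatrix)
    na, nb = size(a, 2), size(b, 2)
    r = Matrix{result_type(dist, a, b)}(undef, na, nb)
    for j in 1:nb, i in 1:na
        r[i, j] = evaluate(dist, view(a, :, i), view(b, :, j))
    end
    return r
end
```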

benchmark/print_table.jl

Lines changed: 2 additions & 1 deletion
@@ -10,6 +10,7 @@ include("benchmarks.jl")
 order = [
     :SqEuclidean,
     :Euclidean,
+    :PeriodicEuclidean,
     :Cityblock,
     :TotalVariation,
     :Chebyshev,
@@ -66,7 +67,7 @@ function print_table(judgement)
             print(io, "| ", getname(dist), " |")
             print(io, @sprintf("%9.6fs | %9.6fs | %7.4f |\n", t_loop / 1e9, t_spec / 1e9, (t_loop / t_spec)))
         end
-        print(STDOUT, String(take!(io)))
+        print(stdout, String(take!(io)))
         println()
     end
 end

src/Distances.jl

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@ export
     # distance classes
     Euclidean,
     SqEuclidean,
+    PeriodicEuclidean,
     Cityblock,
     TotalVariation,
     Chebyshev,
@@ -61,6 +62,7 @@ export
     # convenient functions
     euclidean,
     sqeuclidean,
+    peuclidean,
     cityblock,
     totalvariation,
     jaccard,

src/metrics.jl

Lines changed: 96 additions & 14 deletions
@@ -99,8 +99,11 @@ struct MeanSqDeviation <: SemiMetric end
 struct RMSDeviation <: Metric end
 struct NormRMSDeviation <: Metric end

+struct PeriodicEuclidean{W <: AbstractArray{<: Real}} <: Metric
+    periods::W
+end

-const UnionMetrics = Union{Euclidean,SqEuclidean,Chebyshev,Cityblock,TotalVariation,Minkowski,Hamming,Jaccard,RogersTanimoto,CosineDist,CorrDist,ChiSqDist,KLDivergence,RenyiDivergence,BrayCurtis,JSDivergence,SpanNormDist,GenKLDivergence}
+const UnionMetrics = Union{Euclidean,SqEuclidean,PeriodicEuclidean,Chebyshev,Cityblock,TotalVariation,Minkowski,Hamming,Jaccard,RogersTanimoto,CosineDist,CorrDist,ChiSqDist,KLDivergence,RenyiDivergence,BrayCurtis,JSDivergence,SpanNormDist,GenKLDivergence}

 """
     Euclidean([thresh])
@@ -140,6 +143,26 @@ see [`Euclidean`](@ref).
 """
 SqEuclidean() = SqEuclidean(0)

+"""
+    PeriodicEuclidean(L)
+
+Create a Euclidean metric on a rectangular periodic domain (i.e., a torus or
+a cylinder). Periods per dimension are contained in the vector `L`:
+```math
+\\sqrt{\\sum_i \\min(\\mod(|x_i - y_i|, p), p - \\mod(|x_i - y_i|, p))^2}.
+```
+For dimensions without periodicity put `Inf` in the respective component.
+
+# Example
+```jldoctest
+julia> x, y, L = [0.0, 0.0], [0.75, 0.0], [0.5, Inf];
+
+julia> evaluate(PeriodicEuclidean(L), x, y)
+0.25
+```
+"""
+PeriodicEuclidean() = PeriodicEuclidean(Int[])
+
 ###########################################################
 #
 # Define Evaluate
@@ -148,20 +171,35 @@ SqEuclidean() = SqEuclidean(0)

 const ArraySlice{T} = SubArray{T,1,Array{T,2},Tuple{Base.Slice{Base.OneTo{Int}},Int},true}

+@inline parameters(::UnionMetrics) = nothing
+
 # Specialized for Arrays and avoids a branch on the size
 @inline Base.@propagate_inbounds function evaluate(d::UnionMetrics, a::Union{Array, ArraySlice}, b::Union{Array, ArraySlice})
     @boundscheck if length(a) != length(b)
         throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
     end
+    p = parameters(d)
+    @boundscheck if p !== nothing
+        length(a) != length(p) && throw(DimensionMismatch("arrays have length $(length(a)) but parameters have length $(length(p))."))
+    end
     if length(a) == 0
         return zero(result_type(d, a, b))
     end
     @inbounds begin
         s = eval_start(d, a, b)
-        @simd for I in 1:length(a)
-            ai = a[I]
-            bi = b[I]
-            s = eval_reduce(d, s, eval_op(d, ai, bi))
+        if p === nothing
+            @simd for I in 1:length(a)
+                ai = a[I]
+                bi = b[I]
+                s = eval_reduce(d, s, eval_op(d, ai, bi))
+            end
+        else
+            @simd for I in 1:length(a)
+                aI = a[I]
+                bI = b[I]
+                pI = p[I]
+                s = eval_reduce(d, s, eval_op(d, aI, bI, pI))
+            end
         end
         return eval_end(d, s)
     end
@@ -171,29 +209,53 @@ end
     @boundscheck if length(a) != length(b)
         throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
     end
+    p = parameters(d)
+    @boundscheck if p !== nothing
+        length(a) != length(p) && throw(DimensionMismatch("arrays have length $(length(a)) but parameters have length $(length(p))."))
+    end
     if length(a) == 0
         return zero(result_type(d, a, b))
     end
     @inbounds begin
         s = eval_start(d, a, b)
         if size(a) == size(b)
-            @simd for I in eachindex(a, b)
-                ai = a[I]
-                bi = b[I]
-                s = eval_reduce(d, s, eval_op(d, ai, bi))
+            if p === nothing
+                @simd for I in eachindex(a, b)
+                    ai = a[I]
+                    bi = b[I]
+                    s = eval_reduce(d, s, eval_op(d, ai, bi))
+                end
+            else
+                @simd for I in eachindex(a, b, p)
+                    aI = a[I]
+                    bI = b[I]
+                    pI = p[I]
+                    s = eval_reduce(d, s, eval_op(d, aI, bI, pI))
+                end
             end
         else
-            for (Ia, Ib) in zip(eachindex(a), eachindex(b))
-                ai = a[Ia]
-                bi = b[Ib]
-                s = eval_reduce(d, s, eval_op(d, ai, bi))
+            if p === nothing
+                for (Ia, Ib) in zip(eachindex(a), eachindex(b))
+                    ai = a[Ia]
+                    bi = b[Ib]
+                    s = eval_reduce(d, s, eval_op(d, ai, bi))
+                end
+            else
+                for (Ia, Ib, Ip) in zip(eachindex(a), eachindex(b), eachindex(p))
+                    aI = a[Ia]
+                    bI = b[Ib]
+                    pI = p[Ip]
+                    s = eval_reduce(d, s, eval_op(d, aI, bI, pI))
+                end
             end
         end
     end
     return eval_end(d, s)
 end
 result_type(dist::UnionMetrics, ::AbstractArray{T1}, ::AbstractArray{T2}) where {T1, T2} =
-    typeof(eval_end(dist, eval_op(dist, one(T1), one(T2))))
+    typeof(eval_end(dist, parameters(dist) === nothing ?
+                          eval_op(dist, one(T1), one(T2)) :
+                          eval_op(dist, one(T1), one(T2), one(eltype(dist)))))
 eval_start(d::UnionMetrics, a::AbstractArray, b::AbstractArray) =
     zero(result_type(d, a, b))
 eval_end(d::UnionMetrics, s) = s
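The new `parameters` hook is the key to this change: it returns `nothing` for the existing metrics, which keeps the old two-argument `eval_op` path, and returns the period vector for `PeriodicEuclidean`, which routes evaluation through the new three-argument `eval_op`. A small sketch of those pieces in action, assuming the code from this diff (these are internal, non-exported helpers and may change between versions):

```julia
using Distances

d = PeriodicEuclidean([0.5, Inf])

Distances.parameters(d)               # the period vector [0.5, Inf]; `nothing` for, e.g., Euclidean()
Distances.eval_op(d, 0.0, 0.75, 0.5)  # per-coordinate squared wrapped difference: 0.0625
Distances.eval_end(d, 0.0625)         # square root of the accumulated sum: 0.25
```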
@@ -214,6 +276,26 @@ eval_end(::Euclidean, s) = sqrt(s)
 euclidean(a::AbstractArray, b::AbstractArray) = evaluate(Euclidean(), a, b)
 euclidean(a::Number, b::Number) = evaluate(Euclidean(), a, b)

+# PeriodicEuclidean
+Base.eltype(d::PeriodicEuclidean) = eltype(d.periods)
+@inline parameters(d::PeriodicEuclidean) = d.periods
+@inline function eval_op(d::PeriodicEuclidean, ai, bi, p)
+    s1 = abs(ai - bi)
+    s2 = mod(s1, p)
+    s3 = min(s2, p - s2)
+    abs2(s3)
+end
+@inline eval_reduce(::PeriodicEuclidean, s1, s2) = s1 + s2
+@inline eval_end(::PeriodicEuclidean, s) = sqrt(s)
+function evaluate(dist::PeriodicEuclidean, a::T, b::T) where {T <: Real}
+    p = first(dist.periods)
+    d = mod(abs(a - b), p)
+    min(d, p - d)
+end
+peuclidean(a::AbstractArray, b::AbstractArray, p::AbstractArray{<: Real}) =
+    evaluate(PeriodicEuclidean(p), a, b)
+peuclidean(a::Number, b::Number, p::Real) = evaluate(PeriodicEuclidean([p]), a, b)
+
 # Cityblock
 @inline eval_op(::Cityblock, ai, bi) = abs(ai - bi)
 @inline eval_reduce(::Cityblock, s1, s2) = s1 + s2
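The scalar method makes the wrap-around easy to see: on a domain of period 1.0, points at 0.25 and 3.0 differ by 2.75 along the line but only by 0.25 on the circle (all values below are exact in binary floating point):

```julia
julia> peuclidean(0.25, 3.0, 1.0)
0.25
```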

test/F64.jl

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ Base.log(a::F64) = F64(log(a.x))
 Base.isfinite(a::F64) = isfinite(a.x)
 Base.float(a::F64) = a.x
 Base.rtoldefault(a::Type{F64}, b::Type{F64}) = Base.rtoldefault(Float64, Float64)
+Base.mod(a::F64, b::F64) = mod(a.x, b.x)
 # comparison
 Base.isapprox(a::F64, b::F64) = isapprox(a.x, b.x)
 Base.:<(a::F64, b::F64) = a.x < b.x

0 commit comments
