Skip to content

Commit a37bdff

Browse files
Add Wasserstein distance
1 parent 7f3a28c commit a37bdff

File tree

3 files changed

+124
-0
lines changed

3 files changed

+124
-0
lines changed

src/Distances.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ export
9191
bhattacharyya,
9292
hellinger,
9393
bregman,
94+
wasserstein_distance,
9495

9596
haversine,
9697

@@ -107,5 +108,6 @@ include("haversine.jl")
107108
include("mahalanobis.jl")
108109
include("bhattacharyya.jl")
109110
include("bregman.jl")
111+
include("wasserstein.jl")
110112

111113
end # module end

src/wasserstein.jl

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#####
2+
##### Wasserstein distance
3+
#####
4+
5+
"""
    Side

Abstract supertype of the singleton tags `Left` and `Right`, which are passed as the
third argument of `pysearchsorted` to select how ties are resolved.
"""
abstract type Side end

"Tag for `pysearchsorted`: count the elements strictly less than the searched value."
struct Left <: Side end

"Tag for `pysearchsorted`: count the elements less than or equal to the searched value."
struct Right <: Side end

"""
    pysearchsorted(a, b, ::Left)
    pysearchsorted(a, b, ::Right)

For every element `x` of `b`, locate `x` in the collection `a` and return the
results as an array of zero-based insertion positions. `a` must already be sorted.

- With `Left()` the position is `searchsortedfirst(a, x) - 1`, i.e. the number of
  elements of `a` that are strictly less than `x`.
- With `Right()` the position is `searchsortedlast(a, x)`, i.e. the number of
  elements of `a` that are less than or equal to `x`.

`Ref(a)` keeps `a` from being broadcast over, so only `b` is iterated.

Based on accepted answer in:
https://stackoverflow.com/questions/55339848/julia-vectorized-version-of-searchsorted
"""
pysearchsorted(a, b, ::Left) = searchsortedfirst.(Ref(a), b) .- 1
pysearchsorted(a, b, ::Right) = searchsortedlast.(Ref(a), b)
17+
18+
"""
    compute_integral(u_cdf, v_cdf, deltas, p)

Return `(Σᵢ |u_cdf[i] - v_cdf[i]|^p * deltas[i])^(1/p)`.

The cases `p == 1` (no root needed) and `p == 2` (squaring and `sqrt`) are
evaluated through dedicated expressions; every other `p` uses the general
power form.
"""
function compute_integral(u_cdf, v_cdf, deltas, p)
    # Element-wise gap between the two CDFs, shared by every branch.
    gap = u_cdf - v_cdf
    if p == 1
        return sum(abs.(gap) .* deltas)
    elseif p == 2
        return sqrt(sum(gap .^ 2 .* deltas))
    else
        return sum(abs.(gap) .^ p .* deltas)^(1 / p)
    end
end
27+
28+
# Convert the zero-based insertion counts `cdf_indices` of one sample into CDF values.
# `sorter` is the sorting permutation of that sample and `weights` its optional weights.
function _cdf_from_indices(cdf_indices, sorter, weights)
    # Unweighted: every observation carries the same mass 1/n.
    weights === nothing && return cdf_indices / length(sorter)
    # Weighted: a leading 0 makes index `k + 1` hold the total weight of the first `k`
    # sorted observations; dividing by the grand total normalises the CDF to end at 1.
    sorted_cumweights = vcat([0], cumsum(weights[sorter]))
    return sorted_cumweights[cdf_indices .+ 1] / sorted_cumweights[end]
end

"""
    _cdf_distance(p, u_values, v_values, u_weights=nothing, v_weights=nothing)

Compute a distance between two one-dimensional empirical distributions from the
difference of their CDFs, using exponent `p`.

Both samples are validated with `_validate_distribution`, their values are pooled and
sorted, and each CDF is evaluated on every pooled value except the last. The CDF
differences and the gaps between successive pooled values are then handed to
`compute_integral`.

When a weight argument is `nothing`, all observations of that sample are weighted
equally; otherwise the weights are normalised by their sum.
"""
function _cdf_distance(p, u_values, v_values, u_weights=nothing, v_weights=nothing)
    _validate_distribution(u_values, u_weights)
    _validate_distribution(v_values, v_weights)

    u_sorter = sortperm(u_values)
    v_sorter = sortperm(v_values)

    all_values = sort!(vcat(u_values, v_values))

    # Compute the differences between pairs of successive values of u and v.
    deltas = diff(all_values)

    # Get the respective positions of the values of u and v among the values of
    # both distributions. The last pooled value is dropped because there is no
    # interval to its right; a view avoids copying the pooled array twice.
    grid = @view all_values[1:end-1]
    u_cdf_indices = pysearchsorted(u_values[u_sorter], grid, Right())
    v_cdf_indices = pysearchsorted(v_values[v_sorter], grid, Right())

    # Calculate the CDFs of u and v using their weights, if specified.
    u_cdf = _cdf_from_indices(u_cdf_indices, u_sorter, u_weights)
    v_cdf = _cdf_from_indices(v_cdf_indices, v_sorter, v_weights)

    # Compute the value of the integral based on the CDFs.
    return compute_integral(u_cdf, v_cdf, deltas, p)
end
64+
65+
"""
    _validate_distribution(vals, weights)

Check that `vals` (and `weights`, unless it is `nothing`) describe a usable empirical
distribution. Returns `nothing` when every check passes.

# Throws
- `ArgumentError` if `vals` is empty.
- `DimensionMismatch` if `weights` is given and its length differs from `vals`.
- `ArgumentError` if any weight is negative.
- `ArgumentError` if the weights do not sum to a positive, finite number.
"""
function _validate_distribution(vals, weights)
    # Validate the value array.
    isempty(vals) && throw(ArgumentError("Distribution can't be empty."))
    # Validate the weight array, if specified.
    if weights !== nothing
        if length(weights) != length(vals)
            throw(DimensionMismatch(
                "Value and weight array-likes for the same empirical distribution must be of the same size."
            ))
        end
        any(w -> w < 0, weights) && throw(ArgumentError("All weights must be non-negative."))
        # The chained comparison also rejects a NaN sum, since both comparisons are false.
        if !(0 < sum(weights) < Inf)
            throw(ArgumentError(
                "Weight array-like sum must be positive and finite. Set as nothing for an equal distribution of weight."
            ))
        end
    end
    return nothing
end
80+
81+
"""
    wasserstein_distance(u_values, v_values, u_weights=nothing, v_weights=nothing)

Return the first Wasserstein distance between two one-dimensional empirical
distributions.

The quantity is also called the earth mover's distance: it can be read as the
least amount of "work" needed to reshape `u` into `v`, where "work" is the amount
of distribution weight moved multiplied by the distance over which it is moved.

# Arguments
- `u_values`: values observed in the first (empirical) distribution.
- `v_values`: values observed in the second (empirical) distribution.
- `u_weights`: weight of each entry of `u_values`; `nothing` weights them equally.
- `v_weights`: weight of each entry of `v_values`; `nothing` weights them equally.

Weights need not sum to 1, but their sum must be positive and finite so that they
can be normalized to sum to 1.
"""
wasserstein_distance(u_values, v_values, u_weights=nothing, v_weights=nothing) =
    _cdf_distance(1, u_values, v_values, u_weights, v_weights)

test/test_dists.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,27 @@ end
596596
@test bregman(F, ∇, p, q) ISdist(p, q)
597597
end
598598

599+
@testset "Wasserstein (Earth mover's) distance" begin
    Random.seed!(123)
    N = 5
    u_values = rand(N)
    v_values = rand(N)
    u_weights = rand(N)
    v_weights = rand(N)

    # Invalid inputs: empty samples, mismatched weight lengths, negative weights.
    @test_throws ArgumentError wasserstein_distance([], [])
    @test_throws ArgumentError wasserstein_distance([], v_values)
    @test_throws ArgumentError wasserstein_distance(u_values, [])
    @test_throws DimensionMismatch wasserstein_distance(u_values, v_values, u_weights[1:end-1], v_weights)
    @test_throws DimensionMismatch wasserstein_distance(u_values, v_values, u_weights, v_weights[1:end-1])
    @test_throws ArgumentError wasserstein_distance(u_values, v_values, -u_weights, v_weights)
    @test_throws ArgumentError wasserstein_distance(u_values, v_values, u_weights, -v_weights)

    # Hand-checkable cases that do not depend on the random stream:
    # a pure shift by 5, a weighted pair whose CDFs differ by 0.25 on [0, 1),
    # and the distance of a sample to itself.
    @test wasserstein_distance([0, 1, 3], [5, 6, 8]) ≈ 5
    @test wasserstein_distance([0, 1], [0, 1], [3, 1], [2, 2]) ≈ 0.25
    @test wasserstein_distance(u_values, u_values) == 0

    # TODO: Needs better/more correctness tests
    # NOTE(review): the reference values below are tied to the numbers drawn after
    # `Random.seed!(123)`; confirm they hold on every supported Julia version.
    @test wasserstein_distance(u_values, v_values) ≈ 0.2826796049559892
    @test wasserstein_distance(u_values, v_values, u_weights, v_weights) ≈ 0.28429147575475444
end
619+
599620
#=
600621
@testset "zero allocation colwise!" begin
601622
d = Euclidean()

0 commit comments

Comments
 (0)