Skip to content

Commit 1e88fe4

Browse files
Add Wasserstein distance
1 parent 7f3a28c commit 1e88fe4

File tree

3 files changed

+150
-0
lines changed

3 files changed

+150
-0
lines changed

src/Distances.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ export
5858
RMSDeviation,
5959
NormRMSDeviation,
6060
Bregman,
61+
Wasserstein,
6162

6263
# convenient functions
6364
euclidean,
@@ -91,6 +92,7 @@ export
9192
bhattacharyya,
9293
hellinger,
9394
bregman,
95+
wasserstein,
9496

9597
haversine,
9698

@@ -107,5 +109,6 @@ include("haversine.jl")
107109
include("mahalanobis.jl")
108110
include("bhattacharyya.jl")
109111
include("bregman.jl")
112+
include("wasserstein.jl")
110113

111114
end # module end

src/wasserstein.jl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#####
2+
##### Wasserstein distance
3+
#####
4+
5+
export Wasserstein
6+
7+
# TODO: Make concrete. The weight fields are abstractly typed
# (`Union{AbstractArray{T}, Nothing}`), so the concrete array type is not part of
# the struct's type parameters.
"""
    Wasserstein(u_weights, v_weights)
    Wasserstein()

Distance object holding optional per-sample weights for two 1D empirical
distributions. Calling `dist(u, v)` (or `evaluate(dist, u, v)`) forwards to
`wasserstein(u, v, dist.u_weights, dist.v_weights)`.

Either weight field may be `nothing`. `Wasserstein()` builds an instance with
both weights set to `nothing`.
"""
struct Wasserstein{T<:AbstractFloat} <: PreMetric
    u_weights::Union{AbstractArray{T}, Nothing}
    v_weights::Union{AbstractArray{T}, Nothing}
end

Wasserstein(u_weights, v_weights) = Wasserstein{eltype(u_weights)}(u_weights, v_weights)

# With both weights absent there is no array to infer `T` from, so construction
# would fail. `T` is not used by any field value in that case; fix it to Float64.
Wasserstein() = Wasserstein{Float64}(nothing, nothing)
Wasserstein(::Nothing, ::Nothing) = Wasserstein()

(w::Wasserstein)(u, v) = wasserstein(u, v, w.u_weights, w.v_weights)

evaluate(dist::Wasserstein, u, v) = dist(u, v)
18+
19+
# Tag types used to select, by dispatch, which `pysearchsorted` method runs.
abstract type Side end
struct Left <: Side end
struct Right <: Side end

"""
    pysearchsorted(a, b, side::Side)

For each element `x` of `b`, look up its position in the sorted collection `a`:

- `Left()` returns `searchsortedfirst(a, x) - 1`, i.e. the number of elements of
  `a` that sort strictly before `x`.
- `Right()` returns `searchsortedlast(a, x)`, i.e. the number of elements of `a`
  that sort before or equal to `x`.

Both results are therefore zero-based insertion positions.

Based on accepted answer in:
https://stackoverflow.com/questions/55339848/julia-vectorized-version-of-searchsorted
"""
function pysearchsorted(a, b, ::Left)
    return broadcast(x -> searchsortedfirst(a, x) - 1, b)
end

function pysearchsorted(a, b, ::Right)
    return broadcast(x -> searchsortedlast(a, x), b)
end
31+
32+
"""
    compute_integral(u_cdf, v_cdf, deltas, p)

Combine the element-wise difference of `u_cdf` and `v_cdf` with the interval
widths `deltas`:

- `p == 1`: `sum(|u_cdf - v_cdf| * deltas)`
- `p == 2`: `sqrt(sum((u_cdf - v_cdf)^2 * deltas))`
- otherwise: `sum(|u_cdf - v_cdf|^p * deltas)^(1/p)`
"""
function compute_integral(u_cdf, v_cdf, deltas, p)
    gap = u_cdf - v_cdf
    if p == 1
        return sum(abs.(gap) .* deltas)
    elseif p == 2
        return sqrt(sum(gap .^ 2 .* deltas))
    else
        return sum(abs.(gap) .^ p .* deltas)^(1 / p)
    end
end
41+
42+
"""
    _cdf_distance(p, u_values, v_values, u_weights=nothing, v_weights=nothing)

Compute a distance between two 1D empirical distributions from their CDFs.

Both distributions are validated, their values are pooled and sorted, and each
CDF is evaluated at every pooled point except the last. The CDF difference and
the gaps between successive pooled points are then passed to
`compute_integral` together with the order `p`.

When a weight argument is `nothing`, every sample of that distribution counts
equally. Otherwise the weights are normalized by their total.
"""
function _cdf_distance(p, u_values, v_values, u_weights=nothing, v_weights=nothing)
    _validate_distribution(u_values, u_weights)
    _validate_distribution(v_values, v_weights)

    all_values = sort!(vcat(u_values, v_values))

    # Differences between pairs of successive pooled values.
    deltas = diff(all_values)

    # The CDFs are evaluated at every pooled point but the last, so they line
    # up one-to-one with `deltas`.
    points = all_values[1:end-1]

    u_cdf = _empirical_cdf(u_values, u_weights, points)
    v_cdf = _empirical_cdf(v_values, v_weights, points)

    return compute_integral(u_cdf, v_cdf, deltas, p)
end

# Evaluate the (optionally weighted) empirical CDF of `values` at each of `points`.
function _empirical_cdf(values, weights, points)
    order = sortperm(values)
    # Zero-based count of samples that are <= each point.
    counts = pysearchsorted(values[order], points, Right())

    if weights === nothing
        return counts / length(values)
    end

    # Leading zero so that a count of 0 maps to a cumulative weight of 0;
    # `counts .+ 1` converts the zero-based counts to 1-based indices.
    cumweights = vcat([0], cumsum(weights[order]))
    return cumweights[counts .+ 1] / cumweights[end]
end
78+
79+
"""
    _validate_distribution(vals, weights)

Check that `vals` (and `weights`, unless it is `nothing`) describe a usable
empirical distribution. Returns `nothing` on success.

# Throws
- `ArgumentError` if `vals` is empty, if any weight is negative, or if the
  weights do not sum to a positive, finite number.
- `DimensionMismatch` if `weights` and `vals` differ in length.
"""
function _validate_distribution(vals, weights)
    # Validate the value array.
    isempty(vals) && throw(ArgumentError("Distribution can't be empty."))

    # Validate the weight array, if specified.
    if weights !== nothing
        if length(weights) != length(vals)
            throw(DimensionMismatch("Value and weight array-likes for the same empirical distribution must be of the same size."))
        end
        any(w -> w < 0, weights) && throw(ArgumentError("All weights must be non-negative."))
        # The chained comparison also rejects a NaN sum, since both comparisons are false.
        if !(0 < sum(weights) < Inf)
            throw(ArgumentError("Weight array-like sum must be positive and finite. Set as `nothing` for an equal distribution of weight."))
        end
    end
    return nothing
end
94+
95+
"""
    wasserstein(u_values, v_values, u_weights=nothing, v_weights=nothing)

Return the first Wasserstein distance between two 1D distributions.

The quantity is also called the earth mover's distance: it can be read as the
least amount of "work" needed to reshape `u` into `v`, where "work" is the
amount of distribution weight moved multiplied by the distance it travels.

# Arguments
- `u_values`: values observed in the first (empirical) distribution.
- `v_values`: values observed in the second (empirical) distribution.
- `u_weights`: weight for each value of `u_values`, or `nothing`.
- `v_weights`: weight for each value of `v_values`, or `nothing`.

Weights need not sum to 1, but their sum must be positive and finite so that
they can be normalized to sum to 1.
"""
wasserstein(u_values, v_values, u_weights=nothing, v_weights=nothing) =
    _cdf_distance(1, u_values, v_values, u_weights, v_weights)

test/test_dists.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,38 @@ end
596596
@test bregman(F, ∇, p, q) ISdist(p, q)
597597
end
598598

599+
@testset "Wasserstein (Earth mover's) distance" begin
    Random.seed!(123)
    # TODO: also run with Float32 (currently only Float64 is exercised).
    for T in [Float64]
        N = 5
        u = rand(T, N)
        v = rand(T, N)
        u_weights = rand(T, N)
        v_weights = rand(T, N)

        dist = Wasserstein(u_weights, v_weights)

        test_pairwise(dist, u, v, T)

        # The functor and `evaluate` must agree exactly with the plain function.
        @test evaluate(dist, u, v) === wasserstein(u, v, u_weights, v_weights)
        @test dist(u, v) === wasserstein(u, v, u_weights, v_weights)

        # Input validation.
        @test_throws ArgumentError wasserstein([], [])
        @test_throws ArgumentError wasserstein([], v)
        @test_throws ArgumentError wasserstein(u, [])
        @test_throws DimensionMismatch wasserstein(u, v, u_weights[1:end-1], v_weights)
        @test_throws DimensionMismatch wasserstein(u, v, u_weights, v_weights[1:end-1])
        @test_throws ArgumentError wasserstein(u, v, -u_weights, v_weights)
        @test_throws ArgumentError wasserstein(u, v, u_weights, -v_weights)

        # Correctness on fixed inputs whose distance can be worked out by hand,
        # so the expected values do not depend on the RNG stream.
        @test wasserstein(T[0, 1, 3], T[5, 6, 8]) ≈ 5
        @test wasserstein(T[0, 1], T[0, 1], T[3, 1], T[2, 2]) ≈ 0.25

        # Identity and symmetry.
        @test wasserstein(u, u) == 0
        @test wasserstein(u, v) ≈ wasserstein(v, u)
        @test wasserstein(u, v, u_weights, v_weights) ≈
              wasserstein(v, u, v_weights, u_weights)
    end
end
630+
599631
#=
600632
@testset "zero allocation colwise!" begin
601633
d = Euclidean()

0 commit comments

Comments
 (0)