Skip to content

Commit 484d20b

Browse files
Add Wasserstein distance
1 parent 7f3a28c commit 484d20b

File tree

3 files changed

+702
-552
lines changed

3 files changed

+702
-552
lines changed

src/Distances.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ export
5858
RMSDeviation,
5959
NormRMSDeviation,
6060
Bregman,
61+
Wasserstein,
6162

6263
# convenient functions
6364
euclidean,
@@ -91,6 +92,7 @@ export
9192
bhattacharyya,
9293
hellinger,
9394
bregman,
95+
wasserstein,
9496

9597
haversine,
9698

@@ -107,5 +109,6 @@ include("haversine.jl")
107109
include("mahalanobis.jl")
108110
include("bhattacharyya.jl")
109111
include("bregman.jl")
112+
include("wasserstein.jl")
110113

111114
end # module end

src/wasserstein.jl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#####
2+
##### Wasserstein distance
3+
#####
4+
5+
export Wasserstein
6+
7+
# TODO: Make concrete
8+
struct Wasserstein{T<:AbstractFloat} <: PreMetric
9+
u_weights::Union{AbstractArray{T}, Nothing}
10+
v_weights::Union{AbstractArray{T}, Nothing}
11+
end
12+
13+
Wasserstein(u_weights, v_weights) = Wasserstein{eltype(u_weights)}(u_weights, v_weights)
14+
15+
(w::Wasserstein)(u, v) = wasserstein(u, v, w.u_weights, w.v_weights)
16+
17+
evaluate(dist::Wasserstein, u, v) = dist(u,v)
18+
19+
abstract type Side end
20+
struct Left <: Side end
21+
struct Right <: Side end
22+
23+
"""
24+
pysearchsorted(a,b;side="left")
25+
26+
Based on accepted answer in:
27+
https://stackoverflow.com/questions/55339848/julia-vectorized-version-of-searchsorted
28+
"""
29+
pysearchsorted(a,b,::Left) = searchsortedfirst.(Ref(a),b) .- 1
30+
pysearchsorted(a,b,::Right) = searchsortedlast.(Ref(a),b)
31+
32+
function compute_integral(u_cdf, v_cdf, deltas, p)
33+
if p == 1
34+
return sum(abs.(u_cdf - v_cdf) .* deltas)
35+
end
36+
if p == 2
37+
return sqrt(sum((u_cdf - v_cdf).^2 .* deltas))
38+
end
39+
return sum(abs.(u_cdf - v_cdf).^p .* deltas)^(1/p)
40+
end
41+
42+
function _cdf_distance(p, u_values, v_values, u_weights=nothing, v_weights=nothing)
43+
_validate_distribution(u_values, u_weights)
44+
_validate_distribution(v_values, v_weights)
45+
46+
u_sorter = sortperm(u_values)
47+
v_sorter = sortperm(v_values)
48+
49+
all_values = vcat(u_values, v_values)
50+
sort!(all_values)
51+
52+
# Compute the differences between pairs of successive values of u and v.
53+
deltas = diff(all_values)
54+
55+
# Get the respective positions of the values of u and v among the values of
56+
# both distributions.
57+
u_cdf_indices = pysearchsorted(u_values[u_sorter],all_values[1:end-1], Right())
58+
v_cdf_indices = pysearchsorted(v_values[v_sorter],all_values[1:end-1], Right())
59+
60+
# Calculate the CDFs of u and v using their weights, if specified.
61+
if u_weights == nothing
62+
u_cdf = (u_cdf_indices) / length(u_values)
63+
else
64+
u_sorted_cumweights = vcat([0], cumsum(u_weights[u_sorter]))
65+
u_cdf = u_sorted_cumweights[u_cdf_indices.+1] / u_sorted_cumweights[end]
66+
end
67+
68+
if v_weights == nothing
69+
v_cdf = (v_cdf_indices) / length(v_values)
70+
else
71+
v_sorted_cumweights = vcat([0], cumsum(v_weights[v_sorter]))
72+
v_cdf = v_sorted_cumweights[v_cdf_indices.+1] / v_sorted_cumweights[end]
73+
end
74+
75+
# Compute the value of the integral based on the CDFs.
76+
return compute_integral(u_cdf, v_cdf, deltas, p)
77+
end
78+
79+
function _validate_distribution(vals, weights)
80+
# Validate the value array.
81+
length(vals) == 0 && throw(ArgumentError("Distribution can't be empty."))
82+
# Validate the weight array, if specified.
83+
if weights nothing
84+
if length(weights) != length(vals)
85+
throw(DimensionMismatch("Value and weight array-likes for the same empirical distribution must be of the same size."))
86+
end
87+
any(weights .< 0) && throw(ArgumentError("All weights must be non-negative."))
88+
if !(0 < sum(weights) < Inf)
89+
throw(ArgumentError("Weight array-like sum must be positive and finite. Set as None for an equal distribution of weight."))
90+
end
91+
end
92+
return nothing
93+
end
94+
95+
"""
96+
wasserstein(u_values, v_values, u_weights=nothing, v_weights=nothing)
97+
98+
Compute the first Wasserstein distance between two 1D distributions.
99+
This distance is also known as the earth mover's distance, since it can be
100+
seen as the minimum amount of "work" required to transform ``u`` into
101+
``v``, where "work" is measured as the amount of distribution weight
102+
that must be moved, multiplied by the distance it has to be moved.
103+
104+
- `u_values` Values observed in the (empirical) distribution.
105+
- `v_values` Values observed in the (empirical) distribution.
106+
107+
- `u_weights` Weight for each value.
108+
- `v_weights` Weight for each value.
109+
110+
If the weight sum differs from 1, it must still be positive
111+
and finite so that the weights can be normalized to sum to 1.
112+
"""
113+
function wasserstein(u_values, v_values, u_weights=nothing, v_weights=nothing)
114+
return _cdf_distance(1, u_values, v_values, u_weights, v_weights)
115+
end

0 commit comments

Comments
 (0)