Skip to content

Commit b9a5f38

Browse files
Add Wasserstein distance
1 parent a339344 commit b9a5f38

File tree

2 files changed

+82
-7
lines changed

2 files changed

+82
-7
lines changed

src/Histograms.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,10 @@ Functions in this module:
1111
"""
1212
module Histograms
1313

14-
using PyCall
1514
using NPZ
1615
using ..Utilities
1716

18-
using Conda
19-
Conda.add("scipy")
20-
scsta = pyimport_conda("scipy.stats","")
17+
include("wasserstein.jl")
2118

2219
"""
2320
A simple struct to store samples for empirical PDFs (histograms, distances etc.)
@@ -196,13 +193,14 @@ Returns:
196193
"""
197194
function W1(u_samples::AbstractVector, v_samples::AbstractVector;
198195
normalize = false)
199-
return if !normalize
200-
scsta.wasserstein_distance(u_samples, v_samples)
196+
d = wasserstein_distance(u_samples, v_samples)
197+
if !normalize
198+
return d
201199
else
202200
u_m, u_M = extrema(u_samples)
203201
v_m, v_M = extrema(v_samples)
204202
L = max(u_M, v_M) - min(u_m, v_m)
205-
scsta.wasserstein_distance(u_samples, v_samples) / L
203+
return d / L
206204
end
207205
end
208206

src/wasserstein.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#####
2+
##### Wasserstein distance
3+
#####
4+
5+
function pysearchsorted(a,b;side="left")
6+
if side == "left"
7+
return searchsortedfirst.(Ref(a),b) .- 1
8+
else
9+
return searchsortedlast.(Ref(a),b)
10+
end
11+
end
12+
13+
function _cdf_distance(p, u_values, v_values, u_weights=nothing, v_weights=nothing)
14+
_validate_distribution!(u_values, u_weights)
15+
_validate_distribution!(v_values, v_weights)
16+
17+
u_sorter = sortperm(u_values)
18+
v_sorter = sortperm(v_values)
19+
20+
all_values = vcat(u_values, v_values)
21+
sort!(all_values)
22+
23+
# Compute the differences between pairs of successive values of u and v.
24+
deltas = diff(all_values)
25+
26+
# Get the respective positions of the values of u and v among the values of
27+
# both distributions.
28+
u_cdf_indices = pysearchsorted(u_values[u_sorter],all_values[1:end-1], side="right")
29+
v_cdf_indices = pysearchsorted(v_values[v_sorter],all_values[1:end-1], side="right")
30+
31+
# Calculate the CDFs of u and v using their weights, if specified.
32+
if u_weights == nothing
33+
u_cdf = (u_cdf_indices) / length(u_values)
34+
else
35+
u_sorted_cumweights = vcat([0], cumsum(u_weights[u_sorter]))
36+
u_cdf = u_sorted_cumweights[u_cdf_indices] / u_sorted_cumweights[end]
37+
end
38+
39+
if v_weights == nothing
40+
v_cdf = (v_cdf_indices) / length(v_values)
41+
else
42+
v_sorted_cumweights = vcat([0], cumsum(v_weights[v_sorter]))
43+
v_cdf = v_sorted_cumweights[v_cdf_indices] / v_sorted_cumweights[end]
44+
end
45+
46+
# Compute the value of the integral based on the CDFs.
47+
if p == 1
48+
return sum(abs.(u_cdf - v_cdf) .* deltas)
49+
end
50+
if p == 2
51+
return sqrt(sum((u_cdf - v_cdf).^2 .* deltas))
52+
end
53+
return sum(abs.(u_cdf - v_cdf).^p .* deltas)^(1/p)
54+
end
55+
56+
function _validate_distribution!(vals, weights)
57+
# Validate the value array.
58+
length(vals) == 0 && throw(ValueError("Distribution can't be empty."))
59+
# Validate the weight array, if specified.
60+
if weights nothing
61+
if length(weights) != length(vals)
62+
throw(ValueError("Value and weight array-likes for the same
63+
empirical distribution must be of the same size."))
64+
end
65+
any(weights .< 0) && throw(ValueError("All weights must be non-negative."))
66+
if !(0 < sum(weights) < Inf)
67+
throw(ValueError("Weight array-like sum must be positive and
68+
finite. Set as None for an equal distribution of
69+
weight."))
70+
end
71+
end
72+
return nothing
73+
end
74+
75+
function wasserstein_distance(u_values, v_values, u_weights=nothing, v_weights=nothing)
76+
return _cdf_distance(1, u_values, v_values, u_weights, v_weights)
77+
end

0 commit comments

Comments
 (0)