Skip to content

Commit 6b82d18

Browse files
Add EarthMoversDistance in Julia
1 parent 59edfb2 commit 6b82d18

File tree

4 files changed

+83
-10
lines changed

4 files changed

+83
-10
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ version = "0.1.0"
55

66
[deps]
77
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
8-
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
98
ConfParser = "88353bc9-fd38-507d-a820-d3b43837d6b9"
109
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
1110
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -18,7 +17,6 @@ NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
1817
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1918
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
2019
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
21-
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
2220
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
2321
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2422
ScikitLearn = "3646fa90-6ef7-5e7e-9f22-8aca16db6324"

src/Histograms.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,10 @@ Functions in this module:
1111
"""
1212
module Histograms
1313

14-
using PyCall
1514
using NPZ
1615
using ..Utilities
1716

18-
using Conda
19-
Conda.add("scipy")
20-
scsta = pyimport_conda("scipy.stats","")
17+
include("wasserstein.jl")
2118

2219
"""
2320
A simple struct to store samples for empirical PDFs (histograms, distances etc.)
@@ -196,13 +193,14 @@ Returns:
196193
"""
197194
function W1(u_samples::AbstractVector, v_samples::AbstractVector;
            normalize = false)
    # Plain 1-Wasserstein distance between the two sample sets.
    distance = wasserstein_distance(u_samples, v_samples)
    !normalize && return distance

    # Normalized variant: divide by the length of the interval spanned by
    # both sample sets (largest maximum minus smallest minimum).
    u_lo, u_hi = extrema(u_samples)
    v_lo, v_hi = extrema(v_samples)
    span = max(u_hi, v_hi) - min(u_lo, v_lo)
    return distance / span
end
208206

src/wasserstein.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#####
2+
##### Wasserstein distance
3+
#####
4+
5+
"""
    pysearchsorted(a, b; side="left")

Return, for each element of `b`, a 0-based insertion index into the sorted
collection `a`.

With `side == "left"` the result is `searchsortedfirst(a, x) - 1`, i.e. the
number of elements of `a` strictly less than `x`. For any other value of `side`
the result is `searchsortedlast(a, x)`, i.e. the number of elements of `a` less
than or equal to `x`.
"""
function pysearchsorted(a, b; side="left")
    # Wrap `a` in a Ref so that broadcasting iterates over `b` only.
    haystack = Ref(a)
    side == "left" && return searchsortedfirst.(haystack, b) .- 1
    return searchsortedlast.(haystack, b)
end
12+
13+
"""
    _cdf_distance(p, u_values, v_values, u_weights=nothing, v_weights=nothing)

Compute a distance between two one-dimensional empirical distributions from
their cumulative distribution functions (CDFs):

    (sum(abs.(U .- V) .^ p .* deltas))^(1 / p)

where `U` and `V` are the CDFs of the two distributions evaluated at the sorted
union of all sample values (excluding the largest), and `deltas` are the gaps
between successive values of that union.

# Arguments
- `p`: order of the distance (`p == 1` and `p == 2` take dedicated fast paths).
- `u_values`, `v_values`: sample values of the two distributions.
- `u_weights`, `v_weights`: optional per-sample weights; `nothing` means every
  sample carries equal weight.

# Returns
The scalar distance.

# Throws
Whatever `_validate_distribution!` throws for invalid values or weights.
"""
function _cdf_distance(p, u_values, v_values, u_weights=nothing, v_weights=nothing)
    _validate_distribution!(u_values, u_weights)
    _validate_distribution!(v_values, v_weights)

    all_values = sort!(vcat(u_values, v_values))

    # Differences between pairs of successive values of the pooled samples.
    deltas = diff(all_values)

    # The CDFs are evaluated at every pooled value except the largest one, so
    # that they line up one-to-one with `deltas`.
    points = @view all_values[1:end-1]
    u_cdf = _empirical_cdf(u_values, u_weights, points)
    v_cdf = _empirical_cdf(v_values, v_weights, points)

    # Value of the integral based on the CDFs.
    cdf_gap = abs.(u_cdf .- v_cdf)
    p == 1 && return sum(cdf_gap .* deltas)
    p == 2 && return sqrt(sum(cdf_gap .^ 2 .* deltas))
    return sum(cdf_gap .^ p .* deltas)^(1 / p)
end

# Evaluate the (optionally weighted) empirical CDF of `values` at `points`.
function _empirical_cdf(values, weights, points)
    order = sortperm(values)
    # Number of samples less than or equal to each point (0-based counts).
    counts = pysearchsorted(values[order], points, side="right")
    weights === nothing && return counts / length(values)
    # cumweights[k + 1] holds the total weight of the k smallest samples; the
    # `.+ 1` converts the 0-based counts into Julia's 1-based indices.
    cumweights = vcat(zero(eltype(weights)), cumsum(weights[order]))
    return cumweights[counts .+ 1] / cumweights[end]
end
55+
56+
"""
    _validate_distribution!(vals, weights)

Check that `vals` and `weights` describe a valid empirical distribution.

Despite the trailing `!`, neither argument is modified.

# Arguments
- `vals`: sample values; must be non-empty.
- `weights`: either `nothing` (equal weights) or a collection with the same
  length as `vals`, containing only non-negative entries whose sum is positive
  and finite.

# Returns
`nothing` when the inputs are valid.

# Throws
- `ArgumentError` if `vals` is empty, if the lengths differ, if any weight is
  negative, or if the weights do not sum to a positive finite number.
"""
function _validate_distribution!(vals, weights)
    # Validate the value array.
    isempty(vals) && throw(ArgumentError("Distribution can't be empty."))

    # Nothing more to check when no weights are given.
    weights === nothing && return nothing

    # Validate the weight array.
    if length(weights) != length(vals)
        throw(ArgumentError("Value and weight array-likes for the same " *
                            "empirical distribution must be of the same size."))
    end
    any(w -> w < 0, weights) &&
        throw(ArgumentError("All weights must be non-negative."))
    # A NaN sum fails this chained comparison as well, so it is rejected too.
    if !(0 < sum(weights) < Inf)
        throw(ArgumentError("Weight array-like sum must be positive and finite. " *
                            "Set as nothing for an equal distribution of weight."))
    end
    return nothing
end
74+
75+
"""
    wasserstein_distance(u_values, v_values, u_weights=nothing, v_weights=nothing)

Return the CDF-based distance of order `p = 1` between the two empirical
distributions, by forwarding all arguments to `_cdf_distance`.
"""
wasserstein_distance(u_values, v_values, u_weights=nothing, v_weights=nothing) =
    _cdf_distance(1, u_values, v_values, u_weights, v_weights)

test/Cloudy/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using Pkg; Pkg.add(PackageSpec(url="https://github.com/climate-machine/Cloudy.jl"))
44
using Cloudy
55
using Cloudy.KernelTensors
6-
PDistributions = Cloudy.Distributions
6+
PDistributions = Cloudy.ParticleDistributions
77
Pkg.add("Plots")
88

99
# Import modules

0 commit comments

Comments
 (0)