Skip to content

Commit bd5ace4

Browse files
committed
Implement Histograms.jl; just Wasserstein-1 for now
1 parent f02f722 commit bd5ace4

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

src/Histograms.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
module Histograms
2+
"""
3+
This module is mostly a convenient wrapper of Python functions (numpy, scipy).
4+
5+
Functions in this module:
6+
- wasserstein (2 methods)
7+
8+
"""
9+
10+
import PyCall
11+
scsta = PyCall.pyimport("scipy.stats")
12+
13+
################################################################################
14+
# distance functions ###########################################################
15+
################################################################################
16+
"""
17+
Compute the Wasserstein-1 distance between two distributions from their samples
18+
19+
Parameters:
20+
- u_samples: array-like; samples from the 1st distribution
21+
- v_samples: array-like; samples from the 2nd distribution
22+
- normalize: boolean; whether to normalize the distance by 1/(max-min)
23+
24+
Returns:
25+
- w1: number; the Wasserstein-1 distance
26+
"""
27+
function wasserstein(u_samples::AbstractVector, v_samples::AbstractVector;
28+
normalize = true)
29+
L = maximum([u_samples; v_samples]) - minimum([u_samples; v_samples])
30+
return if !normalize
31+
scsta.wasserstein_distance(u_samples, v_samples)
32+
else
33+
scsta.wasserstein_distance(u_samples, v_samples) / L
34+
end
35+
end
36+
37+
"""
38+
Compute the pairwise Wasserstein-1 distances between two sets of distributions
39+
from their samples
40+
41+
Parameters:
42+
- U_samples: matrix-like; samples from distributions (u1, u2, ...)
43+
- V_samples: matrix-like; samples from distributions (v1, v2, ...)
44+
- normalize: boolean; whether to normalize the distances by 1/(max-min)
45+
46+
`U_samples` and `V_samples` should have samples in the 2nd dimension (along
47+
rows) and have the same 1st dimension (same number of rows). If not, the minimum
48+
of the two (minimum number of rows) will be taken.
49+
50+
`normalize` induces *pairwise* normalization, i.e. it max's and min's are
51+
computed for each pair (u_j, v_j) individually.
52+
53+
Returns:
54+
- w1: array-like; the pairwise Wasserstein-1 distances:
55+
w1(u1, v1)
56+
w1(u2, v2)
57+
...
58+
w1(u_K, v_K)
59+
"""
60+
function wasserstein(U_samples::AbstractMatrix, V_samples::AbstractMatrix;
61+
normalize = true)
62+
if size(U_samples, 1) != size(V_samples, 1)
63+
println(warn("wasserstein"), "sizes of U_samples & V_samples don't match; ",
64+
"will use the minimum of the two")
65+
end
66+
K = min(size(U_samples, 1), size(V_samples, 1))
67+
w1 = zeros(K)
68+
U_sorted = sort(U_samples[1:K, 1:end], dims = 2)
69+
V_sorted = sort(V_samples[1:K, 1:end], dims = 2)
70+
for k in 1:K
71+
w1[k] = wasserstein(U_sorted[k, 1:end], V_sorted[k, 1:end];
72+
normalize = normalize)
73+
end
74+
return w1
75+
end
76+
77+
end # module
78+
79+

0 commit comments

Comments
 (0)