@@ -3,11 +3,13 @@ module Histograms
3
3
This module is mostly a convenient wrapper of Python functions (numpy, scipy).
4
4
5
5
Functions in this module:
6
- - wasserstein (2 methods)
6
+ - W1 (2 methods)
7
7
8
8
"""
9
9
10
10
import PyCall
11
+ include (" ConvenienceFunctions.jl" )
12
+
11
13
scsta = PyCall. pyimport (" scipy.stats" )
12
14
13
15
# ###############################################################################
@@ -22,9 +24,9 @@ Parameters:
22
24
- normalize: boolean; whether to normalize the distance by 1/(max-min)
23
25
24
26
Returns:
25
- - w1: number; the Wasserstein-1 distance
27
+ - w1_uv: number; the Wasserstein-1 distance
26
28
"""
27
- function wasserstein (u_samples:: AbstractVector , v_samples:: AbstractVector ;
29
+ function W1 (u_samples:: AbstractVector , v_samples:: AbstractVector ;
28
30
normalize = true )
29
31
L = maximum ([u_samples; v_samples]) - minimum ([u_samples; v_samples])
30
32
return if ! normalize
@@ -51,27 +53,26 @@ of the two (minimum number of rows) will be taken.
51
53
computed for each pair (u_j, v_j) individually.
52
54
53
55
Returns:
54
- - w1: array-like; the pairwise Wasserstein-1 distances:
56
+ - w1_UV: array-like; the pairwise Wasserstein-1 distances:
55
57
w1(u1, v1)
56
58
w1(u2, v2)
57
59
...
58
60
w1(u_K, v_K)
59
61
"""
60
- function wasserstein (U_samples:: AbstractMatrix , V_samples:: AbstractMatrix ;
62
+ function W1 (U_samples:: AbstractMatrix , V_samples:: AbstractMatrix ;
61
63
normalize = true )
62
64
if size (U_samples, 1 ) != size (V_samples, 1 )
63
- println (warn (" wasserstein " ), " sizes of U_samples & V_samples don't match; " ,
65
+ println (warn (" W1 " ), " sizes of U_samples & V_samples don't match; " ,
64
66
" will use the minimum of the two" )
65
67
end
66
68
K = min (size (U_samples, 1 ), size (V_samples, 1 ))
67
- w1 = zeros (K)
69
+ w1_UV = zeros (K)
68
70
U_sorted = sort (U_samples[1 : K, 1 : end ], dims = 2 )
69
71
V_sorted = sort (V_samples[1 : K, 1 : end ], dims = 2 )
70
72
for k in 1 : K
71
- w1[k] = wasserstein (U_sorted[k, 1 : end ], V_sorted[k, 1 : end ];
72
- normalize = normalize)
73
+ w1_UV[k] = W1 (U_sorted[k, 1 : end ], V_sorted[k, 1 : end ]; normalize = normalize)
73
74
end
74
- return w1
75
+ return w1_UV
75
76
end
76
77
77
78
end # module
0 commit comments