1
+ # ####
2
+ # #### Wasserstein distance
3
+ # ####
4
+
5
# Dispatch tags selecting which insertion point `pysearchsorted` reports.
abstract type Side end

# Insertion point before any entries equal to the searched value.
struct Left <: Side
end

# Insertion point after any entries equal to the searched value.
struct Right <: Side
end

"""
    pysearchsorted(a, b, ::Left)
    pysearchsorted(a, b, ::Right)

For every element of `b`, return a zero-based insertion index into the sorted
collection `a`.

- With `Left()` the result is `searchsortedfirst(a, x) - 1` for each `x` in `b`.
- With `Right()` the result is `searchsortedlast(a, x)` for each `x` in `b`.

Based on the accepted answer in:
https://stackoverflow.com/questions/55339848/julia-vectorized-version-of-searchsorted
"""
function pysearchsorted(a, b, ::Left)
    # `a` is captured by the closure, so it is treated as a single object
    # while broadcasting over the elements of `b`.
    return broadcast(b) do needle
        searchsortedfirst(a, needle) - 1
    end
end

function pysearchsorted(a, b, ::Right)
    return broadcast(b) do needle
        searchsortedlast(a, needle)
    end
end
17
+
18
"""
    compute_integral(u_cdf, v_cdf, deltas, p)

Combine the element-wise difference of `u_cdf` and `v_cdf` with the weights
`deltas` into a single number:

- `p == 1`: `sum(|u_cdf - v_cdf| * deltas)`
- `p == 2`: `sqrt(sum((u_cdf - v_cdf)^2 * deltas))`
- otherwise: `sum(|u_cdf - v_cdf|^p * deltas)^(1/p)`
"""
function compute_integral(u_cdf, v_cdf, deltas, p)
    gap = u_cdf - v_cdf
    if p == 1
        total = sum(abs.(gap) .* deltas)
    elseif p == 2
        total = sqrt(sum(gap .^ 2 .* deltas))
    else
        total = sum(abs.(gap) .^ p .* deltas)^(1 / p)
    end
    return total
end
27
+
28
# Turn zero-based CDF indices into CDF values for one distribution.
# `sorter` is the sorting permutation of that distribution's values; `weights`
# is either `nothing` (every observation counts equally) or one weight per value.
function _cdf_from_indices(cdf_indices, sorter, weights)
    if weights === nothing
        return cdf_indices / length(sorter)
    end
    # Prepend 0 so that index 0 ("no observation reached yet") maps to weight 0.
    cumweights = vcat([0], cumsum(weights[sorter]))
    # `cdf_indices` are zero-based, hence the `.+ 1` for Julia's 1-based arrays.
    return cumweights[cdf_indices .+ 1] / cumweights[end]
end

"""
    _cdf_distance(p, u_values, v_values, u_weights=nothing, v_weights=nothing)

Compute a distance between two 1D empirical distributions from the difference
of their CDFs, combined by `compute_integral` with exponent `p`.

- `u_values`, `v_values`: observed values of each distribution.
- `u_weights`, `v_weights`: optional weight per value; `nothing` means that all
  values of that distribution are weighted equally.

Both inputs are checked with `_validate_distribution` first, so any error it
throws propagates to the caller.
"""
function _cdf_distance(p, u_values, v_values, u_weights=nothing, v_weights=nothing)
    _validate_distribution(u_values, u_weights)
    _validate_distribution(v_values, v_weights)

    u_sorter = sortperm(u_values)
    v_sorter = sortperm(v_values)

    all_values = vcat(u_values, v_values)
    sort!(all_values)

    # Compute the differences between pairs of successive values of u and v.
    deltas = diff(all_values)

    # Get the respective positions of the values of u and v among the values of
    # both distributions. The last pooled value is dropped so that there is one
    # position per entry of `deltas`.
    left_edges = all_values[1:end-1]
    u_cdf_indices = pysearchsorted(u_values[u_sorter], left_edges, Right())
    v_cdf_indices = pysearchsorted(v_values[v_sorter], left_edges, Right())

    # Calculate the CDFs of u and v using their weights, if specified.
    u_cdf = _cdf_from_indices(u_cdf_indices, u_sorter, u_weights)
    v_cdf = _cdf_from_indices(v_cdf_indices, v_sorter, v_weights)

    # Compute the value of the integral based on the CDFs.
    return compute_integral(u_cdf, v_cdf, deltas, p)
end
64
+
65
"""
    _validate_distribution(vals, weights)

Check that `vals` and, unless it is `nothing`, `weights` describe a usable
empirical distribution. Returns `nothing` when every check passes.

# Throws
- `ArgumentError` if `vals` is empty, if any weight is negative, or if the
  weights do not sum to a positive, finite number.
- `DimensionMismatch` if `weights` and `vals` have different lengths.
"""
function _validate_distribution(vals, weights)
    # Validate the value array.
    isempty(vals) && throw(ArgumentError("Distribution can't be empty."))

    # No weights supplied: nothing further to check.
    weights === nothing && return nothing

    # Validate the weight array.
    if length(weights) != length(vals)
        throw(DimensionMismatch("Value and weight array-likes for the same empirical distribution must be of the same size."))
    end
    # Predicate form avoids allocating a temporary Boolean array.
    any(w -> w < 0, weights) && throw(ArgumentError("All weights must be non-negative."))
    # The chained comparison also rejects a NaN sum, since every comparison with NaN is false.
    if !(0 < sum(weights) < Inf)
        throw(ArgumentError("Weight array-like sum must be positive and finite. Set as nothing for an equal distribution of weight."))
    end
    return nothing
end
80
+
81
"""
    wasserstein_distance(u_values, v_values, u_weights=nothing, v_weights=nothing)

Compute the first Wasserstein distance between two 1D distributions.

This distance is also known as the earth mover's distance: it can be seen as
the minimum amount of "work" needed to transform `u` into `v`, where "work" is
the amount of distribution weight that must be moved multiplied by the
distance it has to be moved.

# Arguments
- `u_values`: values observed in the first (empirical) distribution.
- `v_values`: values observed in the second (empirical) distribution.
- `u_weights`: weight for each value of `u_values`.
- `v_weights`: weight for each value of `v_values`.

If a weight sum differs from 1, it must still be positive and finite so that
the weights can be normalized to sum to 1.
"""
wasserstein_distance(u_values, v_values, u_weights=nothing, v_weights=nothing) =
    _cdf_distance(1, u_values, v_values, u_weights, v_weights)
0 commit comments