1
+ # ####
2
+ # #### Wasserstein distance
3
+ # ####
4
+
5
+ export Wasserstein
6
+
7
+ # TODO : Make concrete
8
+ struct Wasserstein{T<: AbstractFloat } <: PreMetric
9
+ u_weights:: Union{AbstractArray{T}, Nothing}
10
+ v_weights:: Union{AbstractArray{T}, Nothing}
11
+ end
12
+
13
+ Wasserstein (u_weights, v_weights) = Wasserstein {eltype(u_weights)} (u_weights, v_weights)
14
+
15
+ (w:: Wasserstein )(u, v) = wasserstein (u, v, w. u_weights, w. v_weights)
16
+
17
+ evaluate (dist:: Wasserstein , u, v) = dist (u,v)
18
+
19
+ abstract type Side end
20
+ struct Left <: Side end
21
+ struct Right <: Side end
22
+
23
+ """
24
+ pysearchsorted(a,b;side="left")
25
+
26
+ Based on accepted answer in:
27
+ https://stackoverflow.com/questions/55339848/julia-vectorized-version-of-searchsorted
28
+ """
29
+ pysearchsorted (a,b,:: Left ) = searchsortedfirst .(Ref (a),b) .- 1
30
+ pysearchsorted (a,b,:: Right ) = searchsortedlast .(Ref (a),b)
31
+
32
+ function compute_integral (u_cdf, v_cdf, deltas, p)
33
+ if p == 1
34
+ return sum (abs .(u_cdf - v_cdf) .* deltas)
35
+ end
36
+ if p == 2
37
+ return sqrt (sum ((u_cdf - v_cdf). ^ 2 .* deltas))
38
+ end
39
+ return sum (abs .(u_cdf - v_cdf).^ p .* deltas)^ (1 / p)
40
+ end
41
+
42
+ function _cdf_distance (p, u_values, v_values, u_weights= nothing , v_weights= nothing )
43
+ _validate_distribution (u_values, u_weights)
44
+ _validate_distribution (v_values, v_weights)
45
+
46
+ u_sorter = sortperm (u_values)
47
+ v_sorter = sortperm (v_values)
48
+
49
+ all_values = vcat (u_values, v_values)
50
+ sort! (all_values)
51
+
52
+ # Compute the differences between pairs of successive values of u and v.
53
+ deltas = diff (all_values)
54
+
55
+ # Get the respective positions of the values of u and v among the values of
56
+ # both distributions.
57
+ u_cdf_indices = pysearchsorted (u_values[u_sorter],all_values[1 : end - 1 ], Right ())
58
+ v_cdf_indices = pysearchsorted (v_values[v_sorter],all_values[1 : end - 1 ], Right ())
59
+
60
+ # Calculate the CDFs of u and v using their weights, if specified.
61
+ if u_weights == nothing
62
+ u_cdf = (u_cdf_indices) / length (u_values)
63
+ else
64
+ u_sorted_cumweights = vcat ([0 ], cumsum (u_weights[u_sorter]))
65
+ u_cdf = u_sorted_cumweights[u_cdf_indices.+ 1 ] / u_sorted_cumweights[end ]
66
+ end
67
+
68
+ if v_weights == nothing
69
+ v_cdf = (v_cdf_indices) / length (v_values)
70
+ else
71
+ v_sorted_cumweights = vcat ([0 ], cumsum (v_weights[v_sorter]))
72
+ v_cdf = v_sorted_cumweights[v_cdf_indices.+ 1 ] / v_sorted_cumweights[end ]
73
+ end
74
+
75
+ # Compute the value of the integral based on the CDFs.
76
+ return compute_integral (u_cdf, v_cdf, deltas, p)
77
+ end
78
+
79
+ function _validate_distribution (vals, weights)
80
+ # Validate the value array.
81
+ length (vals) == 0 && throw (ArgumentError (" Distribution can't be empty." ))
82
+ # Validate the weight array, if specified.
83
+ if weights ≠ nothing
84
+ if length (weights) != length (vals)
85
+ throw (DimensionMismatch (" Value and weight array-likes for the same empirical distribution must be of the same size." ))
86
+ end
87
+ any (weights .< 0 ) && throw (ArgumentError (" All weights must be non-negative." ))
88
+ if ! (0 < sum (weights) < Inf )
89
+ throw (ArgumentError (" Weight array-like sum must be positive and finite. Set as None for an equal distribution of weight." ))
90
+ end
91
+ end
92
+ return nothing
93
+ end
94
+
95
+ """
96
+ wasserstein(u_values, v_values, u_weights=nothing, v_weights=nothing)
97
+
98
+ Compute the first Wasserstein distance between two 1D distributions.
99
+ This distance is also known as the earth mover's distance, since it can be
100
+ seen as the minimum amount of "work" required to transform ``u`` into
101
+ ``v``, where "work" is measured as the amount of distribution weight
102
+ that must be moved, multiplied by the distance it has to be moved.
103
+
104
+ - `u_values` Values observed in the (empirical) distribution.
105
+ - `v_values` Values observed in the (empirical) distribution.
106
+
107
+ - `u_weights` Weight for each value.
108
+ - `v_weights` Weight for each value.
109
+
110
+ If the weight sum differs from 1, it must still be positive
111
+ and finite so that the weights can be normalized to sum to 1.
112
+ """
113
+ function wasserstein (u_values, v_values, u_weights= nothing , v_weights= nothing )
114
+ return _cdf_distance (1 , u_values, v_values, u_weights, v_weights)
115
+ end
0 commit comments