Skip to content

Commit e622818

Browse files
authored
Merge pull request #10 from climate-machine/histograms
Histograms.jl: Implement Wasserstein-1 distance
2 parents 6e63890 + e3bf4a2 commit e622818

File tree

10 files changed

+152
-13
lines changed

10 files changed

+152
-13
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1111
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
14+
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
1415
ScikitLearn = "3646fa90-6ef7-5e7e-9f22-8aca16db6324"
1516
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

src/ConvenienceFunctions.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
2+
# Width that the labels built by `name` and `warn` are right-padded to.
const RPAD = 25

"""
    name(name::AbstractString)

Return `name` with a trailing colon, right-padded with spaces to a width of
`RPAD`. Labels that are already wider than `RPAD` are returned unpadded; nothing
is ever truncated.
"""
name(name::AbstractString) = rpad(string(name, ":"), RPAD)

"""
    warn(name::AbstractString)

Return the warning label `WARNING (name):`, right-padded with spaces to a width
of `RPAD`. Labels that are already wider than `RPAD` are returned unpadded;
nothing is ever truncated.
"""
warn(name::AbstractString) = rpad(string("WARNING (", name, "):"), RPAD)
11+
12+

src/GPR.jl

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ using Parameters # lets you have defaults for fields
99
using EllipsisNotation # adds '..' to refer to the rest of array
1010
import ScikitLearn
1111
import StatsBase
12+
include("ConvenienceFunctions.jl")
13+
1214
const sklearn = ScikitLearn
1315

1416
sklearn.@sk_import gaussian_process : GaussianProcessRegressor
@@ -324,19 +326,6 @@ function plot_fit(gprw::Wrap, plt; plot_95 = false, label = nothing)
324326
end
325327
end
326328

327-
################################################################################
328-
# convenience functions ########################################################
329-
################################################################################
330-
const RPAD = 25
331-
332-
function name(name::AbstractString)
333-
return rpad(name * ":", RPAD)
334-
end
335-
336-
function warn(name::AbstractString)
337-
return rpad("WARNING (" * name * "):", RPAD)
338-
end
339-
340329
end # module
341330

342331

src/Histograms.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
module Histograms
2+
"""
3+
This module is mostly a convenient wrapper of Python functions (numpy, scipy).
4+
5+
Functions in this module:
6+
- W1 (2 methods)
7+
8+
"""
9+
10+
import PyCall
11+
include("ConvenienceFunctions.jl")
12+
13+
# Handle to Python's `scipy.stats` module.
#
# It is declared `const` so that uses of `scsta` inside functions are not
# accesses to an untyped global, and it starts out as a NULL Python object that
# is filled in by `__init__` at module load time. Importing at load time rather
# than at top level is the pattern PyCall documents for modules, because a
# Python pointer stored at top level is not valid after precompilation.
const scsta = PyCall.PyNULL()

function __init__()
    copy!(scsta, PyCall.pyimport("scipy.stats"))
    return nothing
end
14+
15+
################################################################################
16+
# distance functions ###########################################################
17+
################################################################################
18+
"""
19+
Compute the Wasserstein-1 distance between two distributions from their samples
20+
21+
Parameters:
22+
- u_samples: array-like; samples from the 1st distribution
23+
- v_samples: array-like; samples from the 2nd distribution
24+
- normalize: boolean; whether to normalize the distance by 1/(max-min)
25+
26+
Returns:
27+
- w1_uv: number; the Wasserstein-1 distance
28+
"""
29+
function W1(u_samples::AbstractVector, v_samples::AbstractVector;
30+
normalize = true)
31+
L = maximum([u_samples; v_samples]) - minimum([u_samples; v_samples])
32+
return if !normalize
33+
scsta.wasserstein_distance(u_samples, v_samples)
34+
else
35+
scsta.wasserstein_distance(u_samples, v_samples) / L
36+
end
37+
end
38+
39+
"""
40+
Compute the pairwise Wasserstein-1 distances between two sets of distributions
41+
from their samples
42+
43+
Parameters:
44+
- U_samples: matrix-like; samples from distributions (u1, u2, ...)
45+
- V_samples: matrix-like; samples from distributions (v1, v2, ...)
46+
- normalize: boolean; whether to normalize the distances by 1/(max-min)
47+
48+
`U_samples` and `V_samples` should have samples in the 2nd dimension (along
49+
rows) and have the same 1st dimension (same number of rows). If not, the minimum
50+
of the two (minimum number of rows) will be taken.
51+
52+
`normalize` induces *pairwise* normalization, i.e. it max's and min's are
53+
computed for each pair (u_j, v_j) individually.
54+
55+
Returns:
56+
- w1_UV: array-like; the pairwise Wasserstein-1 distances:
57+
w1(u1, v1)
58+
w1(u2, v2)
59+
...
60+
w1(u_K, v_K)
61+
"""
62+
function W1(U_samples::AbstractMatrix, V_samples::AbstractMatrix;
63+
normalize = true)
64+
if size(U_samples, 1) != size(V_samples, 1)
65+
println(warn("W1"), "sizes of U_samples & V_samples don't match; ",
66+
"will use the minimum of the two")
67+
end
68+
K = min(size(U_samples, 1), size(V_samples, 1))
69+
w1_UV = zeros(K)
70+
U_sorted = sort(U_samples[1:K, 1:end], dims = 2)
71+
V_sorted = sort(V_samples[1:K, 1:end], dims = 2)
72+
for k in 1:K
73+
w1_UV[k] = W1(U_sorted[k, 1:end], V_sorted[k, 1:end]; normalize = normalize)
74+
end
75+
return w1_UV
76+
end
77+
78+
end # module
79+
80+
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using Test
2+
3+
include("../../src/ConvenienceFunctions.jl")
4+
5+
################################################################################
6+
# unit testing #################################################################
7+
################################################################################
8+
@testset "unit testing" begin
    # the padding width must be visible at top level once the file is included
    @test isdefined(Main, :RPAD)

    # both helpers return plain `String`s
    @test name("a") isa String
    @test warn("a") isa String

    # short labels are padded out to exactly RPAD
    @test length(name("a")) == RPAD
    @test length(warn("a")) == RPAD

    # labels longer than RPAD are not truncated: only the decoration is added
    # (1 character for ":", 11 characters for "WARNING (" plus "):")
    @test length(name("a" ^ RPAD)) == (RPAD + 1)
    @test length(warn("a" ^ RPAD)) == (RPAD + 11)
end
17+
println("")
18+
19+

test/Histograms/data/x1_bal.npy

781 KB
Binary file not shown.

test/Histograms/data/x1_dns.npy

7.5 MB
Binary file not shown.

test/Histograms/data/x1_onl.npy

1.14 MB
Binary file not shown.

test/Histograms/runtests.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using Test
2+
import NPZ
3+
4+
include("../../src/Histograms.jl")
5+
const Hgm = Histograms
6+
7+
const data_dir = joinpath(@__DIR__, "data")
8+
const x1_bal = NPZ.npzread(joinpath(data_dir, "x1_bal.npy"))
9+
const x1_dns = NPZ.npzread(joinpath(data_dir, "x1_dns.npy"))
10+
const x1_onl = NPZ.npzread(joinpath(data_dir, "x1_onl.npy"))
11+
const w1_dns_bal = 0.03755967829782972
12+
const w1_dns_onl = 0.004489688974663949
13+
const w1_bal_onl = 0.037079734072606625
14+
const w1_dns_bal_unnorm = 0.8190688772401341
15+
16+
################################################################################
17+
# unit testing #################################################################
18+
################################################################################
19+
@testset "unit testing" begin
    arr1 = [1, 1, 1, 2, 3, 4, 4, 4]
    arr2 = [1, 1, 2, 2, 3, 3, 4, 4, 4]

    # Distances are floating-point results, so compare them with a tolerance
    # instead of `==`, which would make the tests depend on rounding details.
    @test isapprox(Hgm.W1(arr1, arr2, normalize = false), 0.25)
    @test isapprox(Hgm.W1(arr2, arr1, normalize = false), 0.25)
    # the distance does not depend on the order of its arguments
    @test isapprox(Hgm.W1(arr1, arr2), Hgm.W1(arr2, arr1))

    # reference values for the stored data sets
    @test isapprox(Hgm.W1(x1_dns, x1_bal), w1_dns_bal)
    @test isapprox(Hgm.W1(x1_dns, x1_onl), w1_dns_onl)
    @test isapprox(Hgm.W1(x1_bal, x1_onl), w1_bal_onl)
    @test isapprox(Hgm.W1(x1_dns, x1_bal, normalize = false), w1_dns_bal_unnorm)

    # matrix method: one distance per row, limited to the smaller row count
    @test size(Hgm.W1(rand(3, 100), rand(3, 100))) == (3,)
    @test size(Hgm.W1(rand(9, 100), rand(3, 100))) == (3,)
end
34+
println("")
35+
36+

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ include("neki.jl")
66

77
for submodule in ["L96m",
88
"GPR",
9+
"Histograms",
10+
"ConvenienceFunctions",
911
]
1012

1113
println("Starting tests for $submodule")

0 commit comments

Comments
 (0)