Skip to content

Commit 43d87c3

Browse files
committed
Add 'Histograms' tests; fix small bugs
1 parent bd5ace4 commit 43d87c3

File tree

6 files changed

+49
-10
lines changed

6 files changed

+49
-10
lines changed

src/Histograms.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ module Histograms
33
This module is mostly a convenient wrapper of Python functions (numpy, scipy).
44
55
Functions in this module:
6-
- wasserstein (2 methods)
6+
- W1 (2 methods)
77
88
"""
99

1010
import PyCall
11+
include("ConvenienceFunctions.jl")
12+
1113
scsta = PyCall.pyimport("scipy.stats")
1214

1315
################################################################################
@@ -22,9 +24,9 @@ Parameters:
2224
- normalize: boolean; whether to normalize the distance by 1/(max-min)
2325
2426
Returns:
25-
- w1: number; the Wasserstein-1 distance
27+
- w1_uv: number; the Wasserstein-1 distance
2628
"""
27-
function wasserstein(u_samples::AbstractVector, v_samples::AbstractVector;
29+
function W1(u_samples::AbstractVector, v_samples::AbstractVector;
2830
normalize = true)
2931
L = maximum([u_samples; v_samples]) - minimum([u_samples; v_samples])
3032
return if !normalize
@@ -51,27 +53,26 @@ of the two (minimum number of rows) will be taken.
5153
computed for each pair (u_j, v_j) individually.
5254
5355
Returns:
54-
- w1: array-like; the pairwise Wasserstein-1 distances:
56+
- w1_UV: array-like; the pairwise Wasserstein-1 distances:
5557
w1(u1, v1)
5658
w1(u2, v2)
5759
...
5860
w1(u_K, v_K)
5961
"""
60-
function wasserstein(U_samples::AbstractMatrix, V_samples::AbstractMatrix;
62+
function W1(U_samples::AbstractMatrix, V_samples::AbstractMatrix;
6163
normalize = true)
6264
if size(U_samples, 1) != size(V_samples, 1)
63-
println(warn("wasserstein"), "sizes of U_samples & V_samples don't match; ",
65+
println(warn("W1"), "sizes of U_samples & V_samples don't match; ",
6466
"will use the minimum of the two")
6567
end
6668
K = min(size(U_samples, 1), size(V_samples, 1))
67-
w1 = zeros(K)
69+
w1_UV = zeros(K)
6870
U_sorted = sort(U_samples[1:K, 1:end], dims = 2)
6971
V_sorted = sort(V_samples[1:K, 1:end], dims = 2)
7072
for k in 1:K
71-
w1[k] = wasserstein(U_sorted[k, 1:end], V_sorted[k, 1:end];
72-
normalize = normalize)
73+
w1_UV[k] = W1(U_sorted[k, 1:end], V_sorted[k, 1:end]; normalize = normalize)
7374
end
74-
return w1
75+
return w1_UV
7576
end
7677

7778
end # module

test/Histograms/data/x1_bal.npy

781 KB
Binary file not shown.

test/Histograms/data/x1_dns.npy

7.5 MB
Binary file not shown.

test/Histograms/data/x1_onl.npy

1.14 MB
Binary file not shown.

test/Histograms/runtests.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using Test
2+
import NPZ
3+
4+
include("../../src/Histograms.jl")
5+
const Hgm = Histograms
6+
7+
const data_dir = joinpath(@__DIR__, "data")
8+
const x1_bal = NPZ.npzread(joinpath(data_dir, "x1_bal.npy"))
9+
const x1_dns = NPZ.npzread(joinpath(data_dir, "x1_dns.npy"))
10+
const x1_onl = NPZ.npzread(joinpath(data_dir, "x1_onl.npy"))
11+
const w1_dns_bal = 0.03755967829782972
12+
const w1_dns_onl = 0.004489688974663949
13+
const w1_bal_onl = 0.037079734072606625
14+
const w1_dns_bal_unnorm = 0.8190688772401341
15+
16+
################################################################################
17+
# unit testing #################################################################
18+
################################################################################
19+
@testset "unit testing" begin
20+
arr1 = [1, 1, 1, 2, 3, 4, 4, 4]
21+
arr2 = [1, 1, 2, 2, 3, 3, 4, 4, 4]
22+
@test Hgm.W1(arr1, arr2, normalize = false) == 0.25
23+
@test Hgm.W1(arr2, arr1, normalize = false) == 0.25
24+
@test Hgm.W1(arr1, arr2) == Hgm.W1(arr2, arr1)
25+
26+
@test isapprox(Hgm.W1(x1_dns, x1_bal), w1_dns_bal)
27+
@test isapprox(Hgm.W1(x1_dns, x1_onl), w1_dns_onl)
28+
@test isapprox(Hgm.W1(x1_bal, x1_onl), w1_bal_onl)
29+
@test isapprox(Hgm.W1(x1_dns, x1_bal, normalize = false), w1_dns_bal_unnorm)
30+
31+
@test size(Hgm.W1(rand(3,100), rand(3,100))) == (3,)
32+
@test size(Hgm.W1(rand(9,100), rand(3,100))) == (3,)
33+
end
34+
println("")
35+
36+

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ include("neki.jl")
66

77
for submodule in ["L96m",
88
"GPR",
9+
"Histograms",
10+
"ConvenienceFunctions",
911
]
1012

1113
println("Starting tests for $submodule")

0 commit comments

Comments
 (0)