Skip to content

Commit 87d70eb

Browse files
committed
Add docstrings; tweak a little
1 parent a393f23 commit 87d70eb

File tree

1 file changed

+203
-24
lines changed

1 file changed

+203
-24
lines changed

src/Histograms.jl

Lines changed: 203 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ module Histograms
33
This module is mostly a convenient wrapper of Python functions (numpy, scipy).
44
55
Functions in this module:
6-
- W1 (2 methods)
6+
- W1 (8 methods)
7+
- load! (3 methods)
8+
- plot (2 methods)
9+
- plot_raw (1 method)
710
811
"""
912

@@ -16,12 +19,33 @@ scsta = PyCall.pyimport("scipy.stats")
1619
################################################################################
1720
# HistData struct ##############################################################
1821
################################################################################
22+
"""
23+
A simple struct to store samples for empirical PDFs (histograms, distances etc.)
24+
25+
Functions that operate on HistData struct:
26+
- W1 (4 methods)
27+
- load! (3 methods)
28+
- plot (2 methods)
29+
- plot_raw (1 method)
30+
"""
1931
mutable struct HistData
2032
samples::Dict{Symbol, AbstractVecOrMat}
2133
end
2234

2335
HistData() = HistData(Dict())
2436

37+
"""
38+
Load samples into a HistData object under a specific key from a vector
39+
40+
Parameters:
41+
- hd: HistData; groups of samples accessed by keys
42+
- key: Symbol; key of the samples group to load samples to
43+
- samples: array-like; samples to load
44+
45+
If the `key` group already exists, the samples are appended. If, in addition,
46+
the samples were in a matrix, they are flattened first, and then the new ones
47+
are added.
48+
"""
2549
function load!(hd::HistData, key::Symbol, samples::AbstractVector)
2650
if haskey(hd.samples, key)
2751
if isa(hd.samples[key], Matrix)
@@ -34,6 +58,19 @@ function load!(hd::HistData, key::Symbol, samples::AbstractVector)
3458
end
3559
end
3660

61+
"""
62+
Load samples into a HistData object under a specific key from a matrix
63+
64+
Parameters:
65+
- hd: HistData; groups of samples accessed by keys
66+
- key: Symbol; key of the samples group to load samples to
67+
- samples: matrix-like; samples to load
68+
69+
If the `key` group already exists, the samples are appended. If, in addition,
70+
the samples were in a vector, the new ones are flattened first, and then added
71+
to the old ones. If the samples were in a matrix and row dimensions don't match,
72+
the minimum of the two dimensions is chosen, the rest is discarded.
73+
"""
3774
function load!(hd::HistData, key::Symbol, samples::AbstractMatrix)
3875
if haskey(hd.samples, key)
3976
if isa(hd.samples[key], Vector)
@@ -53,15 +90,38 @@ function load!(hd::HistData, key::Symbol, samples::AbstractMatrix)
5390
end
5491
end
5592

93+
"""
94+
Load samples into a HistData object under a specific key from a file
95+
96+
Parameters:
97+
- hd: HistData; groups of samples accessed by keys
98+
- key: Symbol; key of the samples group to load samples to
99+
- filename: String; name of a .npy file with samples (vector/matrix)
100+
101+
All the rules about vector/matrix interaction from two other methods apply.
102+
"""
56103
function load!(hd::HistData, key::Symbol, filename::String)
57104
samples = NPZ.npzread(filename)
58-
if isa(samples, Array)
105+
if isa(samples, Array) && ndims(samples) <= 2
59106
load!(hd, key, samples)
60107
else
61-
throw(error("load!: ", filename, " is not an Array; abort"))
108+
throw(error("load!: ", filename, " is not a 1- or 2-d Array; abort"))
62109
end
63110
end
64111

112+
"""
113+
Plot a histogram of a range of data by a specific key
114+
115+
Parameters:
116+
- hd: HistData; groups of samples accessed by keys
117+
- plt: a module used for plotting (only PyPlot supported)
118+
- key: Symbol; key of the samples group to construct histogram from
119+
- k: Int or UnitRange; if samples are in a matrix, which rows to use
120+
- kws: dictionary-like; keyword arguments to pass to plotting function
121+
122+
If `hd.samples[key]` is a matrix, all the samples from `k` row(s) are combined;
123+
if it is a vector, `k` is ignored.
124+
"""
65125
function plot(hd::HistData, plt, key::Symbol, k::Union{Int,UnitRange}; kws...)
66126
if length(hd.samples) == 0
67127
println(warn("plot"), "no samples, nothing to plot")
@@ -80,13 +140,35 @@ function plot(hd::HistData, plt, key::Symbol, k::Union{Int,UnitRange}; kws...)
80140
end
81141

82142
plot_raw(S, plt; kws...)
83-
84143
end
85144

145+
"""
146+
Plot a histogram of whole data by a specific key
147+
148+
Parameters:
149+
- hd: HistData; groups of samples accessed by keys
150+
- plt: a module used for plotting (only PyPlot supported)
151+
- key: Symbol; key of the samples group to construct histogram from
152+
- kws: dictionary-like; keyword arguments to pass to plotting function
153+
"""
86154
plot(hd::HistData, plt, key::Symbol; kws...) =
87155
plot(hd, plt, key, UnitRange(1, size(hd.samples[key], 1)); kws...)
88156

89-
function plot_raw(S::Vector, plt; kws...)
157+
"""
158+
Plot a histogram of samples
159+
160+
Parameters:
161+
- S: array-like; samples to construct histogram from
162+
- plt: a module used for plotting (only PyPlot supported)
163+
- kws: dictionary-like; keyword arguments to pass to plotting function
164+
165+
The keyword arguments `kws` have precedence over defaults, but if left
166+
unspecified, these are the defaults:
167+
bins = "auto"
168+
histtype = "step"
169+
density = true
170+
"""
171+
function plot_raw(S::AbstractVector, plt; kws...)
90172
kwargs_local = Dict{Symbol, Any}()
91173
kwargs_local[:bins] = "auto"
92174
kwargs_local[:histtype] = "step"
@@ -129,19 +211,19 @@ Parameters:
129211
- V_samples: matrix-like; samples from distributions (v1, v2, ...)
130212
- normalize: boolean; whether to normalize the distances by 1/(max-min)
131213
132-
`U_samples` and `V_samples` should have samples in the 2nd dimension (along
133-
rows) and have the same 1st dimension (same number of rows). If not, the minimum
134-
of the two (minimum number of rows) will be taken.
135-
136-
`normalize` induces *pairwise* normalization, i.e. it max's and min's are
137-
computed for each pair (u_j, v_j) individually.
138-
139214
Returns:
140215
- w1_UV: array-like; the pairwise Wasserstein-1 distances:
141216
w1(u1, v1)
142217
w1(u2, v2)
143218
...
144219
w1(u_K, v_K)
220+
221+
`U_samples` and `V_samples` should have samples in the 2nd dimension (along
222+
rows) and have the same 1st dimension (same number of rows). If not, the minimum
223+
of the two (minimum number of rows) will be taken.
224+
225+
`normalize` induces *pairwise* normalization, i.e. it max's and min's are
226+
computed for each pair (u_j, v_j) individually.
145227
"""
146228
function W1(U_samples::AbstractMatrix, V_samples::AbstractMatrix;
147229
normalize = true)
@@ -159,6 +241,37 @@ function W1(U_samples::AbstractMatrix, V_samples::AbstractMatrix;
159241
return w1_UV
160242
end
161243

244+
W1(U_samples::AbstractMatrix, v_samples::AbstractVector; normalize = true) =
245+
W1(vec(U_samples), v_samples; normalize = normalize)
246+
247+
W1(u_samples::AbstractVector, V_samples::AbstractMatrix; normalize = true) =
248+
W1(u_samples, vec(V_samples); normalize = normalize)
249+
250+
"""
251+
Compute pairs of Wasserstein-1 distances between `key` samples and the rest
252+
253+
Parameters:
254+
- hd: HistData; groups of samples accessed by keys
255+
- key: Symbol; key of the samples group to compare everything else against
256+
- k: Int or UnitRange; if samples are in a matrix, which rows to use
257+
258+
Returns:
259+
- key2all_combined: Dict{Symbol, Float64}; pairs of W1 distances
260+
261+
Compute the W1-distances between `hd.samples[key]` and all other `hd.samples`.
262+
If any of the `hd.samples` is a matrix (not a vector) then `k` is used to access
263+
rows of said matrix, and then samples from these rows are combined together.
264+
For any of the `hd.samples` that is a vector, `k` is ignored.
265+
266+
This function is useful when you have one reference (empirical) distribution and
267+
want to compare the rest against that "ground truth" distribution.
268+
269+
Examples:
270+
k2a_row1 = W1(hd, :dns, 1)
271+
K = size(hd.samples[:dns], 1)
272+
k2a_combined = W1(hd, :dns, 1:K)
273+
println(k2a_combined[:bal])
274+
"""
162275
function W1(hd::HistData, key::Symbol, k::Union{Int,UnitRange})
163276
key2all_combined = Dict{Symbol, Float64}()
164277
for ki in keys(hd.samples)
@@ -170,34 +283,100 @@ function W1(hd::HistData, key::Symbol, k::Union{Int,UnitRange})
170283
return key2all_combined
171284
end
172285

286+
"""
287+
Compute a pair of Wasserstein-1 distance between `key1` & `key2` samples
288+
289+
Parameters:
290+
- hd: HistData; groups of samples accessed by keys
291+
- key1: Symbol; key of the first samples group
292+
- key2: Symbol; key of the second samples group
293+
- k: Int or UnitRange; if samples are in a matrix, which rows to use
294+
295+
Returns:
296+
- w1_key1key2: Float64; W1 distance
297+
298+
Compute the W1-distance between `hd.samples[key1]` and `hd.samples[key2]`.
299+
If either of them is a matrix (not a vector) then `k` is used to access rows of
300+
said matrix, and then samples from these rows are combined together.
301+
For vectors, `k` is ignored.
302+
303+
Examples:
304+
w1_dns_bal = W1(hd, :dns, :bal, 1)
305+
K = size(hd.samples[:dns], 1)
306+
w1_dns_bal_combined = W1(hd, :dns, :bal, 1:K)
307+
"""
173308
function W1(hd::HistData, key1::Symbol, key2::Symbol, k::Union{Int,UnitRange})
174-
if isa(hd.samples[key1], Vector)
175-
u = hd.samples[key1]
309+
u = if isa(hd.samples[key1], Vector)
310+
hd.samples[key1]
176311
else
177-
u = vec(hd.samples[key1][k, 1:end])
312+
vec(hd.samples[key1][k, 1:end])
178313
end
179-
if isa(hd.samples[key2], Vector)
180-
v = hd.samples[key2]
314+
v = if isa(hd.samples[key2], Vector)
315+
hd.samples[key2]
181316
else
182-
v = vec(hd.samples[key2][k, 1:end])
317+
vec(hd.samples[key2][k, 1:end])
183318
end
184319
return W1(u, v)
185320
end
186321

322+
"""
323+
Compute vectors of Wasserstein-1 distances between `key` samples and the rest
324+
325+
Parameters:
326+
- hd: HistData; groups of samples accessed by keys
327+
- key: Symbol; key of the samples group to compare everything else against
328+
329+
Returns:
330+
- key2all_vectorized: Dict{Symbol, Union{Vector{Float64}, Float64}};
331+
either vectors or pairs of W1 distances
332+
333+
Compute the W1-distances between `hd.samples[key]` and all other `hd.samples`.
334+
For each pair of samples (`key` and something else) where both groups of samples
335+
are in a matrix, the returned value will be a vector (corresponding to rows of
336+
the matrices); for each pair where at least one of the groups is a vector, the
337+
returned value will be a Float64, and all samples from a matrix are combined.
338+
339+
This function is useful when you have one reference (empirical) distribution and
340+
want to compare the rest against that "ground truth" distribution.
341+
342+
Examples:
343+
k2a_vectorized = W1(hd, :dns)
344+
println(k2a_vectorized[:onl] .> 0.01)
345+
"""
187346
function W1(hd::HistData, key::Symbol)
188-
key2all_vec = Dict{Symbol, Vector{Float64}}()
347+
key2all_vectorized = Dict{Symbol, Union{Vector{Float64}, Float64}}()
189348
for ki in keys(hd.samples)
190349
if ki == key
191350
continue
192351
end
193-
key2all_vec[ki] = W1(hd, ki, key)
352+
key2all_vectorized[ki] = W1(hd, ki, key)
194353
end
195-
return key2all_vec
354+
return key2all_vectorized
196355
end
197356

198-
function W1(hd::HistData, key1::Symbol, key2::Symbol)
199-
return W1(hd.samples[key1], hd.samples[key2])
200-
end
357+
"""
358+
Compute a vector of Wasserstein-1 distances between `key1` & `key2` samples
359+
360+
Parameters:
361+
- hd: HistData; groups of samples accessed by keys
362+
- key1: Symbol; key of the first samples group
363+
- key2: Symbol; key of the second samples group
364+
365+
Returns:
366+
- w1_key1key2: Union{Vector{Float64}, Float64}; W1 distance
367+
368+
Compute the W1-distance between `hd.samples[key1]` and `hd.samples[key2]`.
369+
370+
If both are matrices, the returned value will be a vector (corresponding to rows
371+
of the matrices); if at least one of them is a vector, the returned value will
372+
be a Float64, and all samples from a matrix are combined.
373+
374+
Examples:
375+
w1_dns_bal = W1(hd, :dns, :bal)
376+
println(w1_dns_bal .> 0.01)
377+
"""
378+
W1(hd::HistData, key1::Symbol, key2::Symbol) =
379+
W1(hd.samples[key1], hd.samples[key2])
201380

202381
end # module
203382

0 commit comments

Comments
 (0)