@@ -3,7 +3,10 @@ module Histograms
3
3
This module is mostly a convenient wrapper of Python functions (numpy, scipy).
4
4
5
5
Functions in this module:
6
- - W1 (2 methods)
6
+ - W1 (8 methods)
7
+ - load! (3 methods)
8
+ - plot (2 methods)
9
+ - plot_raw (1 method)
7
10
8
11
"""
9
12
@@ -16,12 +19,33 @@ scsta = PyCall.pyimport("scipy.stats")
16
19
# ###############################################################################
17
20
# HistData struct ##############################################################
18
21
# ###############################################################################
22
+ """
23
+ A simple struct to store samples for empirical PDFs (histograms, distances etc.)
24
+
25
+ Functions that operate on HistData struct:
26
+ - W1 (4 methods)
27
+ - load! (3 methods)
28
+ - plot (2 methods)
29
+ - plot_raw (1 method)
30
+ """
19
31
mutable struct HistData
20
32
samples:: Dict{Symbol, AbstractVecOrMat}
21
33
end
22
34
23
35
HistData () = HistData (Dict ())
24
36
37
+ """
38
+ Load samples into a HistData object under a specific key from a vector
39
+
40
+ Parameters:
41
+ - hd: HistData; groups of samples accessed by keys
42
+ - key: Symbol; key of the samples group to load samples to
43
+ - samples: array-like; samples to load
44
+
45
+ If the `key` group already exists, the samples are appended. If, in addition,
46
+ the samples were in a matrix, they are flattened first, and then the new ones
47
+ are added.
48
+ """
25
49
function load! (hd:: HistData , key:: Symbol , samples:: AbstractVector )
26
50
if haskey (hd. samples, key)
27
51
if isa (hd. samples[key], Matrix)
@@ -34,6 +58,19 @@ function load!(hd::HistData, key::Symbol, samples::AbstractVector)
34
58
end
35
59
end
36
60
61
+ """
62
+ Load samples into a HistData object under a specific key from a matrix
63
+
64
+ Parameters:
65
+ - hd: HistData; groups of samples accessed by keys
66
+ - key: Symbol; key of the samples group to load samples to
67
+ - samples: matrix-like; samples to load
68
+
69
+ If the `key` group already exists, the samples are appended. If, in addition,
70
+ the samples were in a vector, the new ones are flattened first, and then added
71
+ to the old ones. If the samples were in a matrix and row dimensions don't match,
72
+ the minimum of the two dimensions is chosen, the rest is discarded.
73
+ """
37
74
function load! (hd:: HistData , key:: Symbol , samples:: AbstractMatrix )
38
75
if haskey (hd. samples, key)
39
76
if isa (hd. samples[key], Vector)
@@ -53,15 +90,38 @@ function load!(hd::HistData, key::Symbol, samples::AbstractMatrix)
53
90
end
54
91
end
55
92
93
+ """
94
+ Load samples into a HistData object under a specific key from a file
95
+
96
+ Parameters:
97
+ - hd: HistData; groups of samples accessed by keys
98
+ - key: Symbol; key of the samples group to load samples to
99
+ - filename: String; name of a .npy file with samples (vector/matrix)
100
+
101
+ All the rules about vector/matrix interaction from two other methods apply.
102
+ """
56
103
function load! (hd:: HistData , key:: Symbol , filename:: String )
57
104
samples = NPZ. npzread (filename)
58
- if isa (samples, Array)
105
+ if isa (samples, Array) && ndims (samples) <= 2
59
106
load! (hd, key, samples)
60
107
else
61
- throw (error (" load!: " , filename, " is not an Array; abort" ))
108
+ throw (error (" load!: " , filename, " is not a 1- or 2-d Array; abort" ))
62
109
end
63
110
end
64
111
112
+ """
113
+ Plot a histogram of a range of data by a specific key
114
+
115
+ Parameters:
116
+ - hd: HistData; groups of samples accessed by keys
117
+ - plt: a module used for plotting (only PyPlot supported)
118
+ - key: Symbol; key of the samples group to construct histogram from
119
+ - k: Int or UnitRange; if samples are in a matrix, which rows to use
120
+ - kws: dictionary-like; keyword arguments to pass to plotting function
121
+
122
+ If `hd.samples[key]` is a matrix, all the samples from `k` row(s) are combined;
123
+ if it is a vector, `k` is ignored.
124
+ """
65
125
function plot (hd:: HistData , plt, key:: Symbol , k:: Union{Int,UnitRange} ; kws... )
66
126
if length (hd. samples) == 0
67
127
println (warn (" plot" ), " no samples, nothing to plot" )
@@ -80,13 +140,35 @@ function plot(hd::HistData, plt, key::Symbol, k::Union{Int,UnitRange}; kws...)
80
140
end
81
141
82
142
plot_raw (S, plt; kws... )
83
-
84
143
end
85
144
145
+ """
146
+ Plot a histogram of whole data by a specific key
147
+
148
+ Parameters:
149
+ - hd: HistData; groups of samples accessed by keys
150
+ - plt: a module used for plotting (only PyPlot supported)
151
+ - key: Symbol; key of the samples group to construct histogram from
152
+ - kws: dictionary-like; keyword arguments to pass to plotting function
153
+ """
86
154
plot (hd:: HistData , plt, key:: Symbol ; kws... ) =
87
155
plot (hd, plt, key, UnitRange (1 , size (hd. samples[key], 1 )); kws... )
88
156
89
- function plot_raw (S:: Vector , plt; kws... )
157
+ """
158
+ Plot a histogram of samples
159
+
160
+ Parameters:
161
+ - S: array-like; samples to construct histogram from
162
+ - plt: a module used for plotting (only PyPlot supported)
163
+ - kws: dictionary-like; keyword arguments to pass to plotting function
164
+
165
+ The keyword arguments `kws` have precedence over defaults, but if left
166
+ unspecified, these are the defaults:
167
+ bins = "auto"
168
+ histtype = "step"
169
+ density = true
170
+ """
171
+ function plot_raw (S:: AbstractVector , plt; kws... )
90
172
kwargs_local = Dict {Symbol, Any} ()
91
173
kwargs_local[:bins ] = " auto"
92
174
kwargs_local[:histtype ] = " step"
@@ -129,19 +211,19 @@ Parameters:
129
211
- V_samples: matrix-like; samples from distributions (v1, v2, ...)
130
212
- normalize: boolean; whether to normalize the distances by 1/(max-min)
131
213
132
- `U_samples` and `V_samples` should have samples in the 2nd dimension (along
133
- rows) and have the same 1st dimension (same number of rows). If not, the minimum
134
- of the two (minimum number of rows) will be taken.
135
-
136
- `normalize` induces *pairwise* normalization, i.e. it max's and min's are
137
- computed for each pair (u_j, v_j) individually.
138
-
139
214
Returns:
140
215
- w1_UV: array-like; the pairwise Wasserstein-1 distances:
141
216
w1(u1, v1)
142
217
w1(u2, v2)
143
218
...
144
219
w1(u_K, v_K)
220
+
221
+ `U_samples` and `V_samples` should have samples in the 2nd dimension (along
222
+ rows) and have the same 1st dimension (same number of rows). If not, the minimum
223
+ of the two (minimum number of rows) will be taken.
224
+
225
+ `normalize` induces *pairwise* normalization, i.e. it max's and min's are
226
+ computed for each pair (u_j, v_j) individually.
145
227
"""
146
228
function W1 (U_samples:: AbstractMatrix , V_samples:: AbstractMatrix ;
147
229
normalize = true )
@@ -159,6 +241,37 @@ function W1(U_samples::AbstractMatrix, V_samples::AbstractMatrix;
159
241
return w1_UV
160
242
end
161
243
244
+ W1 (U_samples:: AbstractMatrix , v_samples:: AbstractVector ; normalize = true ) =
245
+ W1 (vec (U_samples), v_samples; normalize = normalize)
246
+
247
+ W1 (u_samples:: AbstractVector , V_samples:: AbstractMatrix ; normalize = true ) =
248
+ W1 (u_samples, vec (V_samples); normalize = normalize)
249
+
250
+ """
251
+ Compute pairs of Wasserstein-1 distances between `key` samples and the rest
252
+
253
+ Parameters:
254
+ - hd: HistData; groups of samples accessed by keys
255
+ - key: Symbol; key of the samples group to compare everything else against
256
+ - k: Int or UnitRange; if samples are in a matrix, which rows to use
257
+
258
+ Returns:
259
+ - key2all_combined: Dict{Symbol, Float64}; pairs of W1 distances
260
+
261
+ Compute the W1-distances between `hd.samples[key]` and all other `hd.samples`.
262
+ If any of the `hd.samples` is a matrix (not a vector) then `k` is used to access
263
+ rows of said matrix, and then samples from these rows are combined together.
264
+ For any of the `hd.samples` that is a vector, `k` is ignored.
265
+
266
+ This function is useful when you have one reference (empirical) distribution and
267
+ want to compare the rest against that "ground truth" distribution.
268
+
269
+ Examples:
270
+ k2a_row1 = W1(hd, :dns, 1)
271
+ K = size(hd.samples[:dns], 1)
272
+ k2a_combined = W1(hd, :dns, 1:K)
273
+ println(k2a_combined[:bal])
274
+ """
162
275
function W1 (hd:: HistData , key:: Symbol , k:: Union{Int,UnitRange} )
163
276
key2all_combined = Dict {Symbol, Float64} ()
164
277
for ki in keys (hd. samples)
@@ -170,34 +283,100 @@ function W1(hd::HistData, key::Symbol, k::Union{Int,UnitRange})
170
283
return key2all_combined
171
284
end
172
285
286
+ """
287
+ Compute a pair of Wasserstein-1 distance between `key1` & `key2` samples
288
+
289
+ Parameters:
290
+ - hd: HistData; groups of samples accessed by keys
291
+ - key1: Symbol; key of the first samples group
292
+ - key2: Symbol; key of the second samples group
293
+ - k: Int or UnitRange; if samples are in a matrix, which rows to use
294
+
295
+ Returns:
296
+ - w1_key1key2: Float64; W1 distance
297
+
298
+ Compute the W1-distance between `hd.samples[key1]` and `hd.samples[key2]`.
299
+ If either of them is a matrix (not a vector) then `k` is used to access rows of
300
+ said matrix, and then samples from these rows are combined together.
301
+ For vectors, `k` is ignored.
302
+
303
+ Examples:
304
+ w1_dns_bal = W1(hd, :dns, :bal, 1)
305
+ K = size(hd.samples[:dns], 1)
306
+ w1_dns_bal_combined = W1(hd, :dns, :bal, 1:K)
307
+ """
173
308
function W1 (hd:: HistData , key1:: Symbol , key2:: Symbol , k:: Union{Int,UnitRange} )
174
- if isa (hd. samples[key1], Vector)
175
- u = hd. samples[key1]
309
+ u = if isa (hd. samples[key1], Vector)
310
+ hd. samples[key1]
176
311
else
177
- u = vec (hd. samples[key1][k, 1 : end ])
312
+ vec (hd. samples[key1][k, 1 : end ])
178
313
end
179
- if isa (hd. samples[key2], Vector)
180
- v = hd. samples[key2]
314
+ v = if isa (hd. samples[key2], Vector)
315
+ hd. samples[key2]
181
316
else
182
- v = vec (hd. samples[key2][k, 1 : end ])
317
+ vec (hd. samples[key2][k, 1 : end ])
183
318
end
184
319
return W1 (u, v)
185
320
end
186
321
322
+ """
323
+ Compute vectors of Wasserstein-1 distances between `key` samples and the rest
324
+
325
+ Parameters:
326
+ - hd: HistData; groups of samples accessed by keys
327
+ - key: Symbol; key of the samples group to compare everything else against
328
+
329
+ Returns:
330
+ - key2all_vectorized: Dict{Symbol, Union{Vector{Float64}, Float64}};
331
+ either vectors or pairs of W1 distances
332
+
333
+ Compute the W1-distances between `hd.samples[key]` and all other `hd.samples`.
334
+ For each pair of samples (`key` and something else) where both groups of samples
335
+ are in a matrix, the returned value will be a vector (corresponding to rows of
336
+ the matrices); for each pair where at least one of the groups is a vector, the
337
+ returned value will be a Float64, and all samples from a matrix are combined.
338
+
339
+ This function is useful when you have one reference (empirical) distribution and
340
+ want to compare the rest against that "ground truth" distribution.
341
+
342
+ Examples:
343
+ k2a_vectorized = W1(hd, :dns)
344
+ println(k2a_vectorized[:onl] .> 0.01)
345
+ """
187
346
function W1 (hd:: HistData , key:: Symbol )
188
- key2all_vec = Dict {Symbol, Vector{Float64}} ()
347
+ key2all_vectorized = Dict {Symbol, Union{ Vector{Float64}, Float64}} ()
189
348
for ki in keys (hd. samples)
190
349
if ki == key
191
350
continue
192
351
end
193
- key2all_vec [ki] = W1 (hd, ki, key)
352
+ key2all_vectorized [ki] = W1 (hd, ki, key)
194
353
end
195
- return key2all_vec
354
+ return key2all_vectorized
196
355
end
197
356
198
- function W1 (hd:: HistData , key1:: Symbol , key2:: Symbol )
199
- return W1 (hd. samples[key1], hd. samples[key2])
200
- end
357
+ """
358
+ Compute a vector of Wasserstein-1 distances between `key1` & `key2` samples
359
+
360
+ Parameters:
361
+ - hd: HistData; groups of samples accessed by keys
362
+ - key1: Symbol; key of the first samples group
363
+ - key2: Symbol; key of the second samples group
364
+
365
+ Returns:
366
+ - w1_key1key2: Union{Vector{Float64}, Float64}; W1 distance
367
+
368
+ Compute the W1-distance between `hd.samples[key1]` and `hd.samples[key2]`.
369
+
370
+ If both are matrices, the returned value will be a vector (corresponding to rows
371
+ of the matrices); if at least one of them is a vector, the returned value will
372
+ be a Float64, and all samples from a matrix are combined.
373
+
374
+ Examples:
375
+ w1_dns_bal = W1(hd, :dns, :bal)
376
+ println(w1_dns_bal .> 0.01)
377
+ """
378
+ W1 (hd:: HistData , key1:: Symbol , key2:: Symbol ) =
379
+ W1 (hd. samples[key1], hd. samples[key2])
201
380
202
381
end # module
203
382
0 commit comments