@@ -3,15 +3,181 @@ module Histograms
3
3
This module is mostly a convenient wrapper of Python functions (numpy, scipy).
4
4
5
5
Functions in this module:
6
- - W1 (2 methods)
6
+ - W1 (8 methods)
7
+ - load! (3 methods)
8
+ - plot (2 methods)
9
+ - plot_raw (1 method)
7
10
8
11
"""
9
12
10
13
import PyCall
14
+ import NPZ
11
15
include (" ConvenienceFunctions.jl" )
12
16
13
17
scsta = PyCall. pyimport (" scipy.stats" )
14
18
19
+ # ###############################################################################
20
+ # HistData struct ##############################################################
21
+ # ###############################################################################
22
+ """
23
+ A simple struct to store samples for empirical PDFs (histograms, distances etc.)
24
+
25
+ Functions that operate on HistData struct:
26
+ - W1 (4 methods)
27
+ - load! (3 methods)
28
+ - plot (2 methods)
29
+ - plot_raw (1 method)
30
+ """
31
+ mutable struct HistData
32
+ samples:: Dict{Symbol, AbstractVecOrMat}
33
+ end
34
+
35
+ HistData () = HistData (Dict ())
36
+
37
+ """
38
+ Load samples into a HistData object under a specific key from a vector
39
+
40
+ Parameters:
41
+ - hd: HistData; groups of samples accessed by keys
42
+ - key: Symbol; key of the samples group to load samples to
43
+ - samples: array-like; samples to load
44
+
45
+ If the `key` group already exists, the samples are appended. If, in addition,
46
+ the samples were in a matrix, they are flattened first, and then the new ones
47
+ are added.
48
+ """
49
+ function load! (hd:: HistData , key:: Symbol , samples:: AbstractVector )
50
+ if haskey (hd. samples, key)
51
+ if isa (hd. samples[key], Matrix)
52
+ println (warn (" load!" ),
53
+ " hd.samples[:" , key, " ] is a Matrix; using vec(hd.samples)" )
54
+ end
55
+ hd. samples[key] = vcat (vec (hd. samples[key]), samples)
56
+ else
57
+ hd. samples[key] = samples
58
+ end
59
+ end
60
+
61
+ """
62
+ Load samples into a HistData object under a specific key from a matrix
63
+
64
+ Parameters:
65
+ - hd: HistData; groups of samples accessed by keys
66
+ - key: Symbol; key of the samples group to load samples to
67
+ - samples: matrix-like; samples to load
68
+
69
+ If the `key` group already exists, the samples are appended. If, in addition,
70
+ the samples were in a vector, the new ones are flattened first, and then added
71
+ to the old ones. If the samples were in a matrix and row dimensions don't match,
72
+ the minimum of the two dimensions is chosen, the rest is discarded.
73
+ """
74
+ function load! (hd:: HistData , key:: Symbol , samples:: AbstractMatrix )
75
+ if haskey (hd. samples, key)
76
+ if isa (hd. samples[key], Vector)
77
+ println (warn (" load!" ),
78
+ " hd.samples[:" , key, " ] is a Vector; using vec(samples)" )
79
+ hd. samples[key] = vcat (hd. samples[key], vec (samples))
80
+ elseif size (hd. samples[key], 1 ) != size (samples, 1 )
81
+ println (warn (" load!" ), " sizes of hd.samples & samples don't match; " ,
82
+ " squishing down to match the minimum of the two" )
83
+ K = min (size (hd. samples[key], 1 ), size (samples,1 ))
84
+ hd. samples[key] = hcat (hd. samples[key][1 : K, 1 : end ], samples[1 : K, 1 : end ])
85
+ else
86
+ hd. samples[key] = hcat (hd. samples[key], samples)
87
+ end
88
+ else
89
+ hd. samples[key] = samples
90
+ end
91
+ end
92
+
93
+ """
94
+ Load samples into a HistData object under a specific key from a file
95
+
96
+ Parameters:
97
+ - hd: HistData; groups of samples accessed by keys
98
+ - key: Symbol; key of the samples group to load samples to
99
+ - filename: String; name of a .npy file with samples (vector/matrix)
100
+
101
+ All the rules about vector/matrix interaction from two other methods apply.
102
+ """
103
+ function load! (hd:: HistData , key:: Symbol , filename:: String )
104
+ samples = NPZ. npzread (filename)
105
+ if isa (samples, Array) && ndims (samples) <= 2
106
+ load! (hd, key, samples)
107
+ else
108
+ throw (error (" load!: " , filename, " is not a 1- or 2-d Array; abort" ))
109
+ end
110
+ end
111
+
112
+ """
113
+ Plot a histogram of a range of data by a specific key
114
+
115
+ Parameters:
116
+ - hd: HistData; groups of samples accessed by keys
117
+ - plt: a module used for plotting (only PyPlot supported)
118
+ - key: Symbol; key of the samples group to construct histogram from
119
+ - k: Int or UnitRange; if samples are in a matrix, which rows to use
120
+ - kws: dictionary-like; keyword arguments to pass to plotting function
121
+
122
+ If `hd.samples[key]` is a matrix, all the samples from `k` row(s) are combined;
123
+ if it is a vector, `k` is ignored.
124
+ """
125
+ function plot (hd:: HistData , plt, key:: Symbol , k:: Union{Int,UnitRange} ; kws... )
126
+ if length (hd. samples) == 0
127
+ println (warn (" plot" ), " no samples, nothing to plot" )
128
+ return
129
+ end
130
+ is_pyplot = (Symbol (plt) == :PyPlot )
131
+ if ! is_pyplot
132
+ println (warn (" plot" ), " only PyPlot is supported; not plotting" )
133
+ return
134
+ end
135
+
136
+ if isa (hd. samples[key], Vector)
137
+ S = hd. samples[key]
138
+ else
139
+ S = vec (hd. samples[key][k, 1 : end ])
140
+ end
141
+
142
+ plot_raw (S, plt; kws... )
143
+ end
144
+
145
+ """
146
+ Plot a histogram of whole data by a specific key
147
+
148
+ Parameters:
149
+ - hd: HistData; groups of samples accessed by keys
150
+ - plt: a module used for plotting (only PyPlot supported)
151
+ - key: Symbol; key of the samples group to construct histogram from
152
+ - kws: dictionary-like; keyword arguments to pass to plotting function
153
+ """
154
+ plot (hd:: HistData , plt, key:: Symbol ; kws... ) =
155
+ plot (hd, plt, key, UnitRange (1 , size (hd. samples[key], 1 )); kws... )
156
+
157
+ """
158
+ Plot a histogram of samples
159
+
160
+ Parameters:
161
+ - S: array-like; samples to construct histogram from
162
+ - plt: a module used for plotting (only PyPlot supported)
163
+ - kws: dictionary-like; keyword arguments to pass to plotting function
164
+
165
+ The keyword arguments `kws` have precedence over defaults, but if left
166
+ unspecified, these are the defaults:
167
+ bins = "auto"
168
+ histtype = "step"
169
+ density = true
170
+ """
171
+ function plot_raw (S:: AbstractVector , plt; kws... )
172
+ kwargs_local = Dict {Symbol, Any} ()
173
+ kwargs_local[:bins ] = " auto"
174
+ kwargs_local[:histtype ] = " step"
175
+ kwargs_local[:density ] = true
176
+
177
+ # the order of merge is important! kws have higher priority
178
+ plt. hist (S; merge (kwargs_local, kws)... )
179
+ end
180
+
15
181
# ###############################################################################
16
182
# distance functions ###########################################################
17
183
# ###############################################################################
@@ -27,11 +193,13 @@ Returns:
27
193
- w1_uv: number; the Wasserstein-1 distance
28
194
"""
29
195
function W1 (u_samples:: AbstractVector , v_samples:: AbstractVector ;
30
- normalize = true )
31
- L = maximum ([u_samples; v_samples]) - minimum ([u_samples; v_samples])
196
+ normalize = false )
32
197
return if ! normalize
33
198
scsta. wasserstein_distance (u_samples, v_samples)
34
199
else
200
+ u_m, u_M = extrema (u_samples)
201
+ v_m, v_M = extrema (v_samples)
202
+ L = max (u_M, v_M) - min (u_m, v_m)
35
203
scsta. wasserstein_distance (u_samples, v_samples) / L
36
204
end
37
205
end
@@ -45,22 +213,22 @@ Parameters:
45
213
- V_samples: matrix-like; samples from distributions (v1, v2, ...)
46
214
- normalize: boolean; whether to normalize the distances by 1/(max-min)
47
215
48
- `U_samples` and `V_samples` should have samples in the 2nd dimension (along
49
- rows) and have the same 1st dimension (same number of rows). If not, the minimum
50
- of the two (minimum number of rows) will be taken.
51
-
52
- `normalize` induces *pairwise* normalization, i.e. it max's and min's are
53
- computed for each pair (u_j, v_j) individually.
54
-
55
216
Returns:
56
217
- w1_UV: array-like; the pairwise Wasserstein-1 distances:
57
218
w1(u1, v1)
58
219
w1(u2, v2)
59
220
...
60
221
w1(u_K, v_K)
222
+
223
+ `U_samples` and `V_samples` should have samples in the 2nd dimension (along
224
+ rows) and have the same 1st dimension (same number of rows). If not, the minimum
225
+ of the two (minimum number of rows) will be taken.
226
+
227
+ `normalize` induces *pairwise* normalization, i.e. it max's and min's are
228
+ computed for each pair (u_j, v_j) individually.
61
229
"""
62
230
function W1 (U_samples:: AbstractMatrix , V_samples:: AbstractMatrix ;
63
- normalize = true )
231
+ normalize = false )
64
232
if size (U_samples, 1 ) != size (V_samples, 1 )
65
233
println (warn (" W1" ), " sizes of U_samples & V_samples don't match; " ,
66
234
" will use the minimum of the two" )
@@ -75,6 +243,143 @@ function W1(U_samples::AbstractMatrix, V_samples::AbstractMatrix;
75
243
return w1_UV
76
244
end
77
245
246
+ W1 (U_samples:: AbstractMatrix , v_samples:: AbstractVector ; normalize = false ) =
247
+ W1 (vec (U_samples), v_samples; normalize = normalize)
248
+
249
+ W1 (u_samples:: AbstractVector , V_samples:: AbstractMatrix ; normalize = false ) =
250
+ W1 (u_samples, vec (V_samples); normalize = normalize)
251
+
252
+ """
253
+ Compute pairs of Wasserstein-1 distances between `key` samples and the rest
254
+
255
+ Parameters:
256
+ - hd: HistData; groups of samples accessed by keys
257
+ - key: Symbol; key of the samples group to compare everything else against
258
+ - k: Int or UnitRange; if samples are in a matrix, which rows to use
259
+
260
+ Returns:
261
+ - key2all_combined: Dict{Symbol, Float64}; pairs of W1 distances
262
+
263
+ Compute the W1-distances between `hd.samples[key]` and all other `hd.samples`.
264
+ If any of the `hd.samples` is a matrix (not a vector) then `k` is used to access
265
+ rows of said matrix, and then samples from these rows are combined together.
266
+ For any of the `hd.samples` that is a vector, `k` is ignored.
267
+
268
+ This function is useful when you have one reference (empirical) distribution and
269
+ want to compare the rest against that "ground truth" distribution.
270
+
271
+ Examples:
272
+ k2a_row1 = W1(hd, :dns, 1)
273
+ K = size(hd.samples[:dns], 1)
274
+ k2a_combined = W1(hd, :dns, 1:K)
275
+ println(k2a_combined[:bal])
276
+ """
277
+ function W1 (hd:: HistData , key:: Symbol , k:: Union{Int,UnitRange} )
278
+ key2all_combined = Dict {Symbol, Float64} ()
279
+ for ki in keys (hd. samples)
280
+ if ki == key
281
+ continue
282
+ end
283
+ key2all_combined[ki] = W1 (hd, ki, key, k)
284
+ end
285
+ return key2all_combined
286
+ end
287
+
288
+ """
289
+ Compute a pair of Wasserstein-1 distance between `key1` & `key2` samples
290
+
291
+ Parameters:
292
+ - hd: HistData; groups of samples accessed by keys
293
+ - key1: Symbol; key of the first samples group
294
+ - key2: Symbol; key of the second samples group
295
+ - k: Int or UnitRange; if samples are in a matrix, which rows to use
296
+
297
+ Returns:
298
+ - w1_key1key2: Float64; W1 distance
299
+
300
+ Compute the W1-distance between `hd.samples[key1]` and `hd.samples[key2]`.
301
+ If either of them is a matrix (not a vector) then `k` is used to access rows of
302
+ said matrix, and then samples from these rows are combined together.
303
+ For vectors, `k` is ignored.
304
+
305
+ Examples:
306
+ w1_dns_bal = W1(hd, :dns, :bal, 1)
307
+ K = size(hd.samples[:dns], 1)
308
+ w1_dns_bal_combined = W1(hd, :dns, :bal, 1:K)
309
+ """
310
+ function W1 (hd:: HistData , key1:: Symbol , key2:: Symbol , k:: Union{Int,UnitRange} )
311
+ u = if isa (hd. samples[key1], Vector)
312
+ hd. samples[key1]
313
+ else
314
+ vec (hd. samples[key1][k, 1 : end ])
315
+ end
316
+ v = if isa (hd. samples[key2], Vector)
317
+ hd. samples[key2]
318
+ else
319
+ vec (hd. samples[key2][k, 1 : end ])
320
+ end
321
+ return W1 (u, v)
322
+ end
323
+
324
+ """
325
+ Compute vectors of Wasserstein-1 distances between `key` samples and the rest
326
+
327
+ Parameters:
328
+ - hd: HistData; groups of samples accessed by keys
329
+ - key: Symbol; key of the samples group to compare everything else against
330
+
331
+ Returns:
332
+ - key2all_vectorized: Dict{Symbol, Union{Vector{Float64}, Float64}};
333
+ either vectors or pairs of W1 distances
334
+
335
+ Compute the W1-distances between `hd.samples[key]` and all other `hd.samples`.
336
+ For each pair of samples (`key` and something else) where both groups of samples
337
+ are in a matrix, the returned value will be a vector (corresponding to rows of
338
+ the matrices); for each pair where at least one of the groups is a vector, the
339
+ returned value will be a Float64, and all samples from a matrix are combined.
340
+
341
+ This function is useful when you have one reference (empirical) distribution and
342
+ want to compare the rest against that "ground truth" distribution.
343
+
344
+ Examples:
345
+ k2a_vectorized = W1(hd, :dns)
346
+ println(k2a_vectorized[:onl] .> 0.01)
347
+ """
348
+ function W1 (hd:: HistData , key:: Symbol )
349
+ key2all_vectorized = Dict {Symbol, Union{Vector{Float64}, Float64}} ()
350
+ for ki in keys (hd. samples)
351
+ if ki == key
352
+ continue
353
+ end
354
+ key2all_vectorized[ki] = W1 (hd, ki, key)
355
+ end
356
+ return key2all_vectorized
357
+ end
358
+
359
+ """
360
+ Compute a vector of Wasserstein-1 distances between `key1` & `key2` samples
361
+
362
+ Parameters:
363
+ - hd: HistData; groups of samples accessed by keys
364
+ - key1: Symbol; key of the first samples group
365
+ - key2: Symbol; key of the second samples group
366
+
367
+ Returns:
368
+ - w1_key1key2: Union{Vector{Float64}, Float64}; W1 distance
369
+
370
+ Compute the W1-distance between `hd.samples[key1]` and `hd.samples[key2]`.
371
+
372
+ If both are matrices, the returned value will be a vector (corresponding to rows
373
+ of the matrices); if at least one of them is a vector, the returned value will
374
+ be a Float64, and all samples from a matrix are combined.
375
+
376
+ Examples:
377
+ w1_dns_bal = W1(hd, :dns, :bal)
378
+ println(w1_dns_bal .> 0.01)
379
+ """
380
+ W1 (hd:: HistData , key1:: Symbol , key2:: Symbol ) =
381
+ W1 (hd. samples[key1], hd. samples[key2])
382
+
78
383
end # module
79
384
80
385
0 commit comments