Skip to content

Commit 49828b2

Browse files
authored
Merge pull request #11 from climate-machine/histograms
Implement 'HistData' struct for convenient handling of timeseries
2 parents e622818 + edb1365 commit 49828b2

File tree

4 files changed

+369
-21
lines changed

4 files changed

+369
-21
lines changed

src/Histograms.jl

Lines changed: 316 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,181 @@ module Histograms
33
This module is mostly a convenient wrapper of Python functions (numpy, scipy).
44
55
Functions in this module:
6-
- W1 (2 methods)
6+
- W1 (8 methods)
7+
- load! (3 methods)
8+
- plot (2 methods)
9+
- plot_raw (1 method)
710
811
"""
912

1013
import PyCall
14+
import NPZ
1115
include("ConvenienceFunctions.jl")
1216

1317
scsta = PyCall.pyimport("scipy.stats")
1418

19+
################################################################################
20+
# HistData struct ##############################################################
21+
################################################################################
22+
"""
23+
A simple struct to store samples for empirical PDFs (histograms, distances etc.)
24+
25+
Functions that operate on HistData struct:
26+
- W1 (4 methods)
27+
- load! (3 methods)
28+
- plot (2 methods)
29+
- plot_raw (1 method)
30+
"""
31+
mutable struct HistData
    # one entry per group of samples; each group is a vector or a matrix
    samples::Dict{Symbol, AbstractVecOrMat}
end

# Empty container with no groups loaded yet. The dictionary is created with its
# final key/value types, so the default constructor has nothing to convert.
HistData() = HistData(Dict{Symbol, AbstractVecOrMat}())
36+
37+
"""
38+
Load samples into a HistData object under a specific key from a vector
39+
40+
Parameters:
41+
- hd: HistData; groups of samples accessed by keys
42+
- key: Symbol; key of the samples group to load samples to
43+
- samples: array-like; samples to load
44+
45+
If the `key` group already exists, the samples are appended. If, in addition,
46+
the samples were in a matrix, they are flattened first, and then the new ones
47+
are added.
48+
"""
49+
function load!(hd::HistData, key::Symbol, samples::AbstractVector)
    # First samples for this key: store them as they are.
    if !haskey(hd.samples, key)
        return hd.samples[key] = samples
    end

    stored = hd.samples[key]
    # Test against AbstractMatrix rather than Matrix: the container may hold any
    # AbstractVecOrMat (e.g. a transposed matrix), and every matrix-like group
    # is flattened below, so each of them should trigger the warning.
    if stored isa AbstractMatrix
        println(warn("load!"),
                "hd.samples[:", key, "] is a Matrix; using vec(hd.samples)")
    end
    # vec leaves a vector untouched and flattens a matrix; the new samples are
    # then appended at the end
    return hd.samples[key] = vcat(vec(stored), samples)
end
60+
61+
"""
62+
Load samples into a HistData object under a specific key from a matrix
63+
64+
Parameters:
65+
- hd: HistData; groups of samples accessed by keys
66+
- key: Symbol; key of the samples group to load samples to
67+
- samples: matrix-like; samples to load
68+
69+
If the `key` group already exists, the samples are appended. If, in addition,
70+
the samples were in a vector, the new ones are flattened first, and then added
71+
to the old ones. If the samples were in a matrix and row dimensions don't match,
72+
the minimum of the two dimensions is chosen, the rest is discarded.
73+
"""
74+
function load!(hd::HistData, key::Symbol, samples::AbstractMatrix)
    # First samples for this key: store the matrix as it is.
    if !haskey(hd.samples, key)
        return hd.samples[key] = samples
    end

    stored = hd.samples[key]
    # Test against AbstractVector rather than Vector: ranges, views and other
    # vector-like groups can be stored too, and they must take the "append the
    # flattened matrix" path instead of being hcat-ed as if they were a column.
    if stored isa AbstractVector
        println(warn("load!"),
                "hd.samples[:", key, "] is a Vector; using vec(samples)")
        return hd.samples[key] = vcat(stored, vec(samples))
    end

    rows_stored, rows_new = size(stored, 1), size(samples, 1)
    if rows_stored != rows_new
        println(warn("load!"), "sizes of hd.samples & samples don't match; ",
                "squishing down to match the minimum of the two")
        # keep only the rows both matrices have; the extra rows are discarded
        K = min(rows_stored, rows_new)
        return hd.samples[key] = hcat(stored[1:K, :], samples[1:K, :])
    end

    # same number of rows: new samples become additional columns
    return hd.samples[key] = hcat(stored, samples)
end
92+
93+
"""
94+
Load samples into a HistData object under a specific key from a file
95+
96+
Parameters:
97+
- hd: HistData; groups of samples accessed by keys
98+
- key: Symbol; key of the samples group to load samples to
99+
- filename: String; name of a .npy file with samples (vector/matrix)
100+
101+
All the rules about vector/matrix interaction from two other methods apply.
102+
"""
103+
function load!(hd::HistData, key::Symbol, filename::String)
    samples = NPZ.npzread(filename)
    # VecOrMat matches exactly the 1-d and 2-d `Array`s. Unlike the former
    # `ndims(samples) <= 2` test it also rejects 0-d arrays, which have no
    # matching `load!` method and would otherwise fail with a MethodError.
    if !(samples isa VecOrMat)
        # `error` throws an ErrorException by itself; no extra `throw` needed
        error("load!: ", filename, " is not a 1- or 2-d Array; abort")
    end
    # dispatches to the vector or the matrix method of `load!`
    return load!(hd, key, samples)
end
111+
112+
"""
113+
Plot a histogram of a range of data by a specific key
114+
115+
Parameters:
116+
- hd: HistData; groups of samples accessed by keys
117+
- plt: a module used for plotting (only PyPlot supported)
118+
- key: Symbol; key of the samples group to construct histogram from
119+
- k: Int or UnitRange; if samples are in a matrix, which rows to use
120+
- kws: dictionary-like; keyword arguments to pass to plotting function
121+
122+
If `hd.samples[key]` is a matrix, all the samples from `k` row(s) are combined;
123+
if it is a vector, `k` is ignored.
124+
"""
125+
function plot(hd::HistData, plt, key::Symbol, k::Union{Int,UnitRange}; kws...)
    if isempty(hd.samples)
        println(warn("plot"), "no samples, nothing to plot")
        return
    end
    # the plotting module is identified by its name only
    if Symbol(plt) != :PyPlot
        println(warn("plot"), "only PyPlot is supported; not plotting")
        return
    end

    group = hd.samples[key]
    # Any vector-like group (not just `Vector`, e.g. a range or a view) is used
    # whole and `k` is ignored; for a matrix the samples of row(s) `k` are
    # combined into one flat vector.
    S = group isa AbstractVector ? group : vec(group[k, :])

    return plot_raw(S, plt; kws...)
end
144+
145+
"""
146+
Plot a histogram of whole data by a specific key
147+
148+
Parameters:
149+
- hd: HistData; groups of samples accessed by keys
150+
- plt: a module used for plotting (only PyPlot supported)
151+
- key: Symbol; key of the samples group to construct histogram from
152+
- kws: dictionary-like; keyword arguments to pass to plotting function
153+
"""
154+
function plot(hd::HistData, plt, key::Symbol; kws...)
    # select everything along the first dimension of the stored samples
    every_row = 1:size(hd.samples[key], 1)
    return plot(hd, plt, key, every_row; kws...)
end
156+
157+
"""
158+
Plot a histogram of samples
159+
160+
Parameters:
161+
- S: array-like; samples to construct histogram from
162+
- plt: a module used for plotting (only PyPlot supported)
163+
- kws: dictionary-like; keyword arguments to pass to plotting function
164+
165+
The keyword arguments `kws` have precedence over defaults, but if left
166+
unspecified, these are the defaults:
167+
bins = "auto"
168+
histtype = "step"
169+
density = true
170+
"""
171+
function plot_raw(S::AbstractVector, plt; kws...)
    # fallback settings, used for every keyword the caller leaves unspecified
    defaults = Dict{Symbol, Any}(
        :bins     => "auto",
        :histtype => "step",
        :density  => true,
    )
    # argument order matters: on duplicate keys merge keeps the value from the
    # last collection, so the caller's `kws` override the defaults
    options = merge(defaults, kws)
    return plt.hist(S; options...)
end
180+
15181
################################################################################
16182
# distance functions ###########################################################
17183
################################################################################
@@ -27,11 +193,13 @@ Returns:
27193
- w1_uv: number; the Wasserstein-1 distance
28194
"""
29195
function W1(u_samples::AbstractVector, v_samples::AbstractVector;
            normalize = false)
    if normalize
        # width of the interval that contains both groups of samples
        lo_u, hi_u = extrema(u_samples)
        lo_v, hi_v = extrema(v_samples)
        span = max(hi_u, hi_v) - min(lo_u, lo_v)
        return scsta.wasserstein_distance(u_samples, v_samples) / span
    end
    # plain (unnormalized) distance as computed by scipy.stats
    return scsta.wasserstein_distance(u_samples, v_samples)
end
@@ -45,22 +213,22 @@ Parameters:
45213
- V_samples: matrix-like; samples from distributions (v1, v2, ...)
46214
- normalize: boolean; whether to normalize the distances by 1/(max-min)
47215
48-
`U_samples` and `V_samples` should have samples in the 2nd dimension (along
49-
rows) and have the same 1st dimension (same number of rows). If not, the minimum
50-
of the two (minimum number of rows) will be taken.
51-
52-
`normalize` induces *pairwise* normalization, i.e. it max's and min's are
53-
computed for each pair (u_j, v_j) individually.
54-
55216
Returns:
56217
- w1_UV: array-like; the pairwise Wasserstein-1 distances:
57218
w1(u1, v1)
58219
w1(u2, v2)
59220
...
60221
w1(u_K, v_K)
222+
223+
`U_samples` and `V_samples` should have samples in the 2nd dimension (along
224+
rows) and have the same 1st dimension (same number of rows). If not, the minimum
225+
of the two (minimum number of rows) will be taken.
226+
227+
`normalize` induces *pairwise* normalization, i.e. the max's and min's are
228+
computed for each pair (u_j, v_j) individually.
61229
"""
62230
function W1(U_samples::AbstractMatrix, V_samples::AbstractMatrix;
63-
normalize = true)
231+
normalize = false)
64232
if size(U_samples, 1) != size(V_samples, 1)
65233
println(warn("W1"), "sizes of U_samples & V_samples don't match; ",
66234
"will use the minimum of the two")
@@ -75,6 +243,143 @@ function W1(U_samples::AbstractMatrix, V_samples::AbstractMatrix;
75243
return w1_UV
76244
end
77245

246+
function W1(U_samples::AbstractMatrix, v_samples::AbstractVector; normalize = false)
    # flatten the matrix so that all of its samples form a single group
    return W1(vec(U_samples), v_samples; normalize = normalize)
end
248+
249+
function W1(u_samples::AbstractVector, V_samples::AbstractMatrix; normalize = false)
    # flatten the matrix so that all of its samples form a single group
    return W1(u_samples, vec(V_samples); normalize = normalize)
end
251+
252+
"""
253+
Compute pairs of Wasserstein-1 distances between `key` samples and the rest
254+
255+
Parameters:
256+
- hd: HistData; groups of samples accessed by keys
257+
- key: Symbol; key of the samples group to compare everything else against
258+
- k: Int or UnitRange; if samples are in a matrix, which rows to use
259+
260+
Returns:
261+
- key2all_combined: Dict{Symbol, Float64}; pairs of W1 distances
262+
263+
Compute the W1-distances between `hd.samples[key]` and all other `hd.samples`.
264+
If any of the `hd.samples` is a matrix (not a vector) then `k` is used to access
265+
rows of said matrix, and then samples from these rows are combined together.
266+
For any of the `hd.samples` that is a vector, `k` is ignored.
267+
268+
This function is useful when you have one reference (empirical) distribution and
269+
want to compare the rest against that "ground truth" distribution.
270+
271+
Examples:
272+
k2a_row1 = W1(hd, :dns, 1)
273+
K = size(hd.samples[:dns], 1)
274+
k2a_combined = W1(hd, :dns, 1:K)
275+
println(k2a_combined[:bal])
276+
"""
277+
function W1(hd::HistData, key::Symbol, k::Union{Int,UnitRange})
    # every group except the reference one
    others = Iterators.filter(name -> name != key, keys(hd.samples))
    # one distance per remaining group, each computed against `key`
    return Dict{Symbol, Float64}(name => W1(hd, name, key, k) for name in others)
end
287+
288+
"""
289+
Compute the Wasserstein-1 distance between `key1` & `key2` samples
290+
291+
Parameters:
292+
- hd: HistData; groups of samples accessed by keys
293+
- key1: Symbol; key of the first samples group
294+
- key2: Symbol; key of the second samples group
295+
- k: Int or UnitRange; if samples are in a matrix, which rows to use
296+
297+
Returns:
298+
- w1_key1key2: Float64; W1 distance
299+
300+
Compute the W1-distance between `hd.samples[key1]` and `hd.samples[key2]`.
301+
If either of them is a matrix (not a vector) then `k` is used to access rows of
302+
said matrix, and then samples from these rows are combined together.
303+
For vectors, `k` is ignored.
304+
305+
Examples:
306+
w1_dns_bal = W1(hd, :dns, :bal, 1)
307+
K = size(hd.samples[:dns], 1)
308+
w1_dns_bal_combined = W1(hd, :dns, :bal, 1:K)
309+
"""
310+
function W1(hd::HistData, key1::Symbol, key2::Symbol, k::Union{Int,UnitRange})
    # Any vector-like group (not just `Vector`, e.g. a range or a view) is used
    # whole and `k` is ignored; for a matrix the samples of row(s) `k` are
    # combined into one flat vector.
    combined(group) = group isa AbstractVector ? group : vec(group[k, :])

    u = combined(hd.samples[key1])
    v = combined(hd.samples[key2])
    # both arguments are vectors here, so this is the vector/vector method
    return W1(u, v)
end
323+
324+
"""
325+
Compute vectors of Wasserstein-1 distances between `key` samples and the rest
326+
327+
Parameters:
328+
- hd: HistData; groups of samples accessed by keys
329+
- key: Symbol; key of the samples group to compare everything else against
330+
331+
Returns:
332+
- key2all_vectorized: Dict{Symbol, Union{Vector{Float64}, Float64}};
333+
either vectors or pairs of W1 distances
334+
335+
Compute the W1-distances between `hd.samples[key]` and all other `hd.samples`.
336+
For each pair of samples (`key` and something else) where both groups of samples
337+
are in a matrix, the returned value will be a vector (corresponding to rows of
338+
the matrices); for each pair where at least one of the groups is a vector, the
339+
returned value will be a Float64, and all samples from a matrix are combined.
340+
341+
This function is useful when you have one reference (empirical) distribution and
342+
want to compare the rest against that "ground truth" distribution.
343+
344+
Examples:
345+
k2a_vectorized = W1(hd, :dns)
346+
println(k2a_vectorized[:onl] .> 0.01)
347+
"""
348+
function W1(hd::HistData, key::Symbol)
    # a value is either a vector of distances or a single distance
    Distances = Union{Vector{Float64}, Float64}
    # every group except the reference one
    others = Iterators.filter(name -> name != key, keys(hd.samples))
    return Dict{Symbol, Distances}(name => W1(hd, name, key) for name in others)
end
358+
359+
"""
360+
Compute a vector of Wasserstein-1 distances between `key1` & `key2` samples
361+
362+
Parameters:
363+
- hd: HistData; groups of samples accessed by keys
364+
- key1: Symbol; key of the first samples group
365+
- key2: Symbol; key of the second samples group
366+
367+
Returns:
368+
- w1_key1key2: Union{Vector{Float64}, Float64}; W1 distance
369+
370+
Compute the W1-distance between `hd.samples[key1]` and `hd.samples[key2]`.
371+
372+
If both are matrices, the returned value will be a vector (corresponding to rows
373+
of the matrices); if at least one of them is a vector, the returned value will
374+
be a Float64, and all samples from a matrix are combined.
375+
376+
Examples:
377+
w1_dns_bal = W1(hd, :dns, :bal)
378+
println(w1_dns_bal .> 0.01)
379+
"""
380+
function W1(hd::HistData, key1::Symbol, key2::Symbol)
    # which W1 method runs is decided by dispatch on the two stored containers
    first_group = hd.samples[key1]
    second_group = hd.samples[key2]
    return W1(first_group, second_group)
end
382+
78383
end # module
79384

80385

test/Histograms/data/x2_bal.npy

781 KB
Binary file not shown.

test/Histograms/data/x2_onl.npy

1.14 MB
Binary file not shown.

0 commit comments

Comments
 (0)