@@ -8,10 +8,94 @@ Functions in this module:
8
8
"""
9
9
10
10
import PyCall
11
+ import NPZ
11
12
include (" ConvenienceFunctions.jl" )
12
13
13
14
scsta = PyCall. pyimport (" scipy.stats" )
14
15
16
+ # ###############################################################################
17
+ # HistData struct ##############################################################
18
+ # ###############################################################################
19
+ mutable struct HistData
20
+ samples:: Dict{Symbol, AbstractVecOrMat}
21
+ end
22
+
23
+ HistData () = HistData (Dict ())
24
+
25
+ function load! (hd:: HistData , key:: Symbol , samples:: AbstractVector )
26
+ if haskey (hd. samples, key)
27
+ if isa (hd. samples[key], Matrix)
28
+ println (warn (" load!" ),
29
+ " hd.samples[:" , key, " ] is a Matrix; using vec(hd.samples)" )
30
+ end
31
+ hd. samples[key] = vcat (vec (hd. samples[key]), samples)
32
+ else
33
+ hd. samples[key] = samples
34
+ end
35
+ end
36
+
37
+ function load! (hd:: HistData , key:: Symbol , samples:: AbstractMatrix )
38
+ if haskey (hd. samples, key)
39
+ if isa (hd. samples[key], Vector)
40
+ println (warn (" load!" ),
41
+ " hd.samples[:" , key, " ] is a Vector; using vec(samples)" )
42
+ hd. samples[key] = vcat (hd. samples[key], vec (samples))
43
+ elseif size (hd. samples[key], 1 ) != size (samples, 1 )
44
+ println (warn (" load!" ), " sizes of hd.samples & samples don't match; " ,
45
+ " squishing down to match the minimum of the two" )
46
+ K = min (size (hd. samples[key], 1 ), size (samples,1 ))
47
+ hd. samples[key] = hcat (hd. samples[key][1 : K, 1 : end ], samples[1 : K, 1 : end ])
48
+ else
49
+ hd. samples[key] = hcat (hd. samples[key], samples)
50
+ end
51
+ else
52
+ hd. samples[key] = samples
53
+ end
54
+ end
55
+
56
+ function load! (hd:: HistData , key:: Symbol , filename:: String )
57
+ samples = NPZ. npzread (filename)
58
+ if isa (samples, Array)
59
+ load! (hd, key, samples)
60
+ else
61
+ throw (error (" load!: " , filename, " is not an Array; abort" ))
62
+ end
63
+ end
64
+
65
+ function plot (hd:: HistData , plt, key:: Symbol , k:: Union{Int,UnitRange} ; kws... )
66
+ if length (hd. samples) == 0
67
+ println (warn (" plot" ), " no samples, nothing to plot" )
68
+ return
69
+ end
70
+ is_pyplot = (Symbol (plt) == :PyPlot )
71
+ if ! is_pyplot
72
+ println (warn (" plot" ), " only PyPlot is supported; not plotting" )
73
+ return
74
+ end
75
+
76
+ if isa (hd. samples[key], Vector)
77
+ S = hd. samples[key]
78
+ else
79
+ S = vec (hd. samples[key][k, 1 : end ])
80
+ end
81
+
82
+ plot_raw (S, plt; kws... )
83
+
84
+ end
85
+
86
+ plot (hd:: HistData , plt, key:: Symbol ; kws... ) =
87
+ plot (hd, plt, key, UnitRange (1 , size (hd. samples[key], 1 )); kws... )
88
+
89
+ function plot_raw (S:: Vector , plt; kws... )
90
+ kwargs_local = Dict {Symbol, Any} ()
91
+ kwargs_local[:bins ] = " auto"
92
+ kwargs_local[:histtype ] = " step"
93
+ kwargs_local[:density ] = true
94
+
95
+ # the order of merge is important! kws have higher priority
96
+ plt. hist (S; merge (kwargs_local, kws)... )
97
+ end
98
+
15
99
# ###############################################################################
16
100
# distance functions ###########################################################
17
101
# ###############################################################################
@@ -75,6 +159,46 @@ function W1(U_samples::AbstractMatrix, V_samples::AbstractMatrix;
75
159
return w1_UV
76
160
end
77
161
162
+ function W1 (hd:: HistData , key:: Symbol , k:: Union{Int,UnitRange} )
163
+ key2all_combined = Dict {Symbol, Float64} ()
164
+ for ki in keys (hd. samples)
165
+ if ki == key
166
+ continue
167
+ end
168
+ key2all_combined[ki] = W1 (hd, ki, key, k)
169
+ end
170
+ return key2all_combined
171
+ end
172
+
173
+ function W1 (hd:: HistData , key1:: Symbol , key2:: Symbol , k:: Union{Int,UnitRange} )
174
+ if isa (hd. samples[key1], Vector)
175
+ u = hd. samples[key1]
176
+ else
177
+ u = vec (hd. samples[key1][k, 1 : end ])
178
+ end
179
+ if isa (hd. samples[key2], Vector)
180
+ v = hd. samples[key2]
181
+ else
182
+ v = vec (hd. samples[key2][k, 1 : end ])
183
+ end
184
+ return W1 (u, v)
185
+ end
186
+
187
+ function W1 (hd:: HistData , key:: Symbol )
188
+ key2all_vec = Dict {Symbol, Vector{Float64}} ()
189
+ for ki in keys (hd. samples)
190
+ if ki == key
191
+ continue
192
+ end
193
+ key2all_vec[ki] = W1 (hd, ki, key)
194
+ end
195
+ return key2all_vec
196
+ end
197
+
198
+ function W1 (hd:: HistData , key1:: Symbol , key2:: Symbol )
199
+ return W1 (hd. samples[key1], hd. samples[key2])
200
+ end
201
+
78
202
end # module
79
203
80
204
0 commit comments