Skip to content

Commit a393f23

Browse files
committed
Implement HistData struct & supporting funcs
1 parent e622818 commit a393f23

File tree

1 file changed

+124
-0
lines changed

1 file changed

+124
-0
lines changed

src/Histograms.jl

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,94 @@ Functions in this module:
88
"""
99

1010
import PyCall
11+
import NPZ
1112
include("ConvenienceFunctions.jl")
1213

1314
scsta = PyCall.pyimport("scipy.stats")
1415

16+
################################################################################
17+
# HistData struct ##############################################################
18+
################################################################################
19+
mutable struct HistData
20+
samples::Dict{Symbol, AbstractVecOrMat}
21+
end
22+
23+
HistData() = HistData(Dict())
24+
25+
function load!(hd::HistData, key::Symbol, samples::AbstractVector)
26+
if haskey(hd.samples, key)
27+
if isa(hd.samples[key], Matrix)
28+
println(warn("load!"),
29+
"hd.samples[:", key, "] is a Matrix; using vec(hd.samples)")
30+
end
31+
hd.samples[key] = vcat(vec(hd.samples[key]), samples)
32+
else
33+
hd.samples[key] = samples
34+
end
35+
end
36+
37+
function load!(hd::HistData, key::Symbol, samples::AbstractMatrix)
38+
if haskey(hd.samples, key)
39+
if isa(hd.samples[key], Vector)
40+
println(warn("load!"),
41+
"hd.samples[:", key, "] is a Vector; using vec(samples)")
42+
hd.samples[key] = vcat(hd.samples[key], vec(samples))
43+
elseif size(hd.samples[key], 1) != size(samples, 1)
44+
println(warn("load!"), "sizes of hd.samples & samples don't match; ",
45+
"squishing down to match the minimum of the two")
46+
K = min(size(hd.samples[key], 1), size(samples,1))
47+
hd.samples[key] = hcat(hd.samples[key][1:K, 1:end], samples[1:K, 1:end])
48+
else
49+
hd.samples[key] = hcat(hd.samples[key], samples)
50+
end
51+
else
52+
hd.samples[key] = samples
53+
end
54+
end
55+
56+
function load!(hd::HistData, key::Symbol, filename::String)
57+
samples = NPZ.npzread(filename)
58+
if isa(samples, Array)
59+
load!(hd, key, samples)
60+
else
61+
throw(error("load!: ", filename, " is not an Array; abort"))
62+
end
63+
end
64+
65+
function plot(hd::HistData, plt, key::Symbol, k::Union{Int,UnitRange}; kws...)
66+
if length(hd.samples) == 0
67+
println(warn("plot"), "no samples, nothing to plot")
68+
return
69+
end
70+
is_pyplot = (Symbol(plt) == :PyPlot)
71+
if !is_pyplot
72+
println(warn("plot"), "only PyPlot is supported; not plotting")
73+
return
74+
end
75+
76+
if isa(hd.samples[key], Vector)
77+
S = hd.samples[key]
78+
else
79+
S = vec(hd.samples[key][k, 1:end])
80+
end
81+
82+
plot_raw(S, plt; kws...)
83+
84+
end
85+
86+
plot(hd::HistData, plt, key::Symbol; kws...) =
87+
plot(hd, plt, key, UnitRange(1, size(hd.samples[key], 1)); kws...)
88+
89+
function plot_raw(S::Vector, plt; kws...)
90+
kwargs_local = Dict{Symbol, Any}()
91+
kwargs_local[:bins] = "auto"
92+
kwargs_local[:histtype] = "step"
93+
kwargs_local[:density] = true
94+
95+
# the order of merge is important! kws have higher priority
96+
plt.hist(S; merge(kwargs_local, kws)...)
97+
end
98+
1599
################################################################################
16100
# distance functions ###########################################################
17101
################################################################################
@@ -75,6 +159,46 @@ function W1(U_samples::AbstractMatrix, V_samples::AbstractMatrix;
75159
return w1_UV
76160
end
77161

162+
function W1(hd::HistData, key::Symbol, k::Union{Int,UnitRange})
163+
key2all_combined = Dict{Symbol, Float64}()
164+
for ki in keys(hd.samples)
165+
if ki == key
166+
continue
167+
end
168+
key2all_combined[ki] = W1(hd, ki, key, k)
169+
end
170+
return key2all_combined
171+
end
172+
173+
function W1(hd::HistData, key1::Symbol, key2::Symbol, k::Union{Int,UnitRange})
174+
if isa(hd.samples[key1], Vector)
175+
u = hd.samples[key1]
176+
else
177+
u = vec(hd.samples[key1][k, 1:end])
178+
end
179+
if isa(hd.samples[key2], Vector)
180+
v = hd.samples[key2]
181+
else
182+
v = vec(hd.samples[key2][k, 1:end])
183+
end
184+
return W1(u, v)
185+
end
186+
187+
function W1(hd::HistData, key::Symbol)
188+
key2all_vec = Dict{Symbol, Vector{Float64}}()
189+
for ki in keys(hd.samples)
190+
if ki == key
191+
continue
192+
end
193+
key2all_vec[ki] = W1(hd, ki, key)
194+
end
195+
return key2all_vec
196+
end
197+
198+
function W1(hd::HistData, key1::Symbol, key2::Symbol)
199+
return W1(hd.samples[key1], hd.samples[key2])
200+
end
201+
78202
end # module
79203

80204

0 commit comments

Comments
 (0)