@@ -5,7 +5,7 @@ struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext
5
5
end
6
6
7
7
function PointwiseLikelihoodContext (
8
- likelihoods= Dict {VarName,Vector{Float64}} (),
8
+ likelihoods= OrderedDict {VarName,Vector{Float64}} (),
9
9
context:: AbstractContext = LikelihoodContext (),
10
10
)
11
11
return PointwiseLikelihoodContext {typeof(likelihoods),typeof(context)} (
@@ -20,7 +20,7 @@ function setchildcontext(context::PointwiseLikelihoodContext, child)
20
20
end
21
21
22
22
function Base. push! (
23
- context:: PointwiseLikelihoodContext{Dict {VarName,Vector{Float64}}} ,
23
+ context:: PointwiseLikelihoodContext{<:AbstractDict {VarName,Vector{Float64}}} ,
24
24
vn:: VarName ,
25
25
logp:: Real ,
26
26
)
@@ -30,13 +30,15 @@ function Base.push!(
30
30
end
31
31
32
32
function Base. push! (
33
- context:: PointwiseLikelihoodContext{Dict{VarName,Float64}} , vn:: VarName , logp:: Real
33
+ context:: PointwiseLikelihoodContext{<:AbstractDict{VarName,Float64}} ,
34
+ vn:: VarName ,
35
+ logp:: Real ,
34
36
)
35
37
return context. loglikelihoods[vn] = logp
36
38
end
37
39
38
40
function Base. push! (
39
- context:: PointwiseLikelihoodContext{Dict {String,Vector{Float64}}} ,
41
+ context:: PointwiseLikelihoodContext{<:AbstractDict {String,Vector{Float64}}} ,
40
42
vn:: VarName ,
41
43
logp:: Real ,
42
44
)
@@ -46,13 +48,15 @@ function Base.push!(
46
48
end
47
49
48
50
function Base. push! (
49
- context:: PointwiseLikelihoodContext{Dict{String,Float64}} , vn:: VarName , logp:: Real
51
+ context:: PointwiseLikelihoodContext{<:AbstractDict{String,Float64}} ,
52
+ vn:: VarName ,
53
+ logp:: Real ,
50
54
)
51
55
return context. loglikelihoods[string (vn)] = logp
52
56
end
53
57
54
58
function Base. push! (
55
- context:: PointwiseLikelihoodContext{Dict {String,Vector{Float64}}} ,
59
+ context:: PointwiseLikelihoodContext{<:AbstractDict {String,Vector{Float64}}} ,
56
60
vn:: String ,
57
61
logp:: Real ,
58
62
)
@@ -62,7 +66,9 @@ function Base.push!(
62
66
end
63
67
64
68
function Base. push! (
65
- context:: PointwiseLikelihoodContext{Dict{String,Float64}} , vn:: String , logp:: Real
69
+ context:: PointwiseLikelihoodContext{<:AbstractDict{String,Float64}} ,
70
+ vn:: String ,
71
+ logp:: Real ,
66
72
)
67
73
return context. loglikelihoods[vn] = logp
68
74
end
@@ -126,11 +132,11 @@ end
126
132
"""
127
133
pointwise_loglikelihoods(model::Model, chain::Chains, keytype = String)
128
134
129
- Runs `model` on each sample in `chain` returning a `Dict {String, Matrix{Float64}}`
135
+ Runs `model` on each sample in `chain` returning a `OrderedDict {String, Matrix{Float64}}`
130
136
with keys corresponding to symbols of the observations, and values being matrices
131
137
of shape `(num_chains, num_samples)`.
132
138
133
- `keytype` specifies what the type of the keys used in the returned `Dict ` are.
139
+ `keytype` specifies what the type of the keys used in the returned `OrderedDict ` are.
134
140
Currently, only `String` and `VarName` are supported.
135
141
136
142
# Notes
@@ -179,25 +185,25 @@ julia> model = demo(randn(3), randn());
179
185
julia> chain = sample(model, MH(), 10);
180
186
181
187
julia> pointwise_loglikelihoods(model, chain)
182
- Dict{String,Array{Float64,2}} with 4 entries:
183
- "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
188
+ OrderedDict{String,Array{Float64,2}} with 4 entries:
184
189
"xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
185
190
"xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
191
+ "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
186
192
"y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
187
193
188
194
julia> pointwise_loglikelihoods(model, chain, String)
189
- Dict{String,Array{Float64,2}} with 4 entries:
190
- "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
195
+ OrderedDict{String,Array{Float64,2}} with 4 entries:
191
196
"xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
192
197
"xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
198
+ "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
193
199
"y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
194
200
195
201
julia> pointwise_loglikelihoods(model, chain, VarName)
196
- Dict{VarName,Array{Float64,2}} with 4 entries:
197
- xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
198
- y => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
202
+ OrderedDict{VarName,Array{Float64,2}} with 4 entries:
199
203
xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
204
+ xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
200
205
xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
206
+ y => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
201
207
```
202
208
203
209
## Broadcasting
@@ -224,7 +230,7 @@ julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first.((ℓ[@varname(x[1])
224
230
function pointwise_loglikelihoods (model:: Model , chain, keytype:: Type{T} = String) where {T}
225
231
# Get the data by executing the model once
226
232
vi = VarInfo (model)
227
- context = PointwiseLikelihoodContext (Dict {T,Vector{Float64}} ())
233
+ context = PointwiseLikelihoodContext (OrderedDict {T,Vector{Float64}} ())
228
234
229
235
iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
230
236
for (sample_idx, chain_idx) in iters
@@ -237,15 +243,15 @@ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String)
237
243
238
244
niters = size (chain, 1 )
239
245
nchains = size (chain, 3 )
240
- loglikelihoods = Dict (
246
+ loglikelihoods = OrderedDict (
241
247
varname => reshape (logliks, niters, nchains) for
242
248
(varname, logliks) in context. loglikelihoods
243
249
)
244
250
return loglikelihoods
245
251
end
246
252
247
253
function pointwise_loglikelihoods (model:: Model , varinfo:: AbstractVarInfo )
248
- context = PointwiseLikelihoodContext (Dict {VarName,Vector{Float64}} ())
254
+ context = PointwiseLikelihoodContext (OrderedDict {VarName,Vector{Float64}} ())
249
255
model (varinfo, context)
250
256
return context. loglikelihoods
251
257
end
0 commit comments