Skip to content

Commit 64089cf

Browse files
Use OrderedDict for pointwise_loglikelihoods instead of Dict (#475)
* use ordered dict for pointwise_loglikelihoods instead of Dict * bump patch version * updated docstring * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 0ffa0e5 commit 64089cf

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.22.3"
3+
version = "0.22.4"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/loglikelihoods.jl

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext
55
end
66

77
function PointwiseLikelihoodContext(
8-
likelihoods=Dict{VarName,Vector{Float64}}(),
8+
likelihoods=OrderedDict{VarName,Vector{Float64}}(),
99
context::AbstractContext=LikelihoodContext(),
1010
)
1111
return PointwiseLikelihoodContext{typeof(likelihoods),typeof(context)}(
@@ -20,7 +20,7 @@ function setchildcontext(context::PointwiseLikelihoodContext, child)
2020
end
2121

2222
function Base.push!(
23-
context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}},
23+
context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Vector{Float64}}},
2424
vn::VarName,
2525
logp::Real,
2626
)
@@ -30,13 +30,15 @@ function Base.push!(
3030
end
3131

3232
function Base.push!(
33-
context::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real
33+
context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Float64}},
34+
vn::VarName,
35+
logp::Real,
3436
)
3537
return context.loglikelihoods[vn] = logp
3638
end
3739

3840
function Base.push!(
39-
context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}},
41+
context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}},
4042
vn::VarName,
4143
logp::Real,
4244
)
@@ -46,13 +48,15 @@ function Base.push!(
4648
end
4749

4850
function Base.push!(
49-
context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real
51+
context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}},
52+
vn::VarName,
53+
logp::Real,
5054
)
5155
return context.loglikelihoods[string(vn)] = logp
5256
end
5357

5458
function Base.push!(
55-
context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}},
59+
context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}},
5660
vn::String,
5761
logp::Real,
5862
)
@@ -62,7 +66,9 @@ function Base.push!(
6266
end
6367

6468
function Base.push!(
65-
context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real
69+
context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}},
70+
vn::String,
71+
logp::Real,
6672
)
6773
return context.loglikelihoods[vn] = logp
6874
end
@@ -126,11 +132,11 @@ end
126132
"""
127133
pointwise_loglikelihoods(model::Model, chain::Chains, keytype = String)
128134
129-
Runs `model` on each sample in `chain` returning a `Dict{String, Matrix{Float64}}`
135+
Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
130136
with keys corresponding to symbols of the observations, and values being matrices
131137
of shape `(num_chains, num_samples)`.
132138
133-
`keytype` specifies what the type of the keys used in the returned `Dict` are.
139+
`keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
134140
Currently, only `String` and `VarName` are supported.
135141
136142
# Notes
@@ -179,25 +185,25 @@ julia> model = demo(randn(3), randn());
179185
julia> chain = sample(model, MH(), 10);
180186
181187
julia> pointwise_loglikelihoods(model, chain)
182-
Dict{String,Array{Float64,2}} with 4 entries:
183-
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
188+
OrderedDict{String,Array{Float64,2}} with 4 entries:
184189
"xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
185190
"xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
191+
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
186192
"y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
187193
188194
julia> pointwise_loglikelihoods(model, chain, String)
189-
Dict{String,Array{Float64,2}} with 4 entries:
190-
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
195+
OrderedDict{String,Array{Float64,2}} with 4 entries:
191196
"xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
192197
"xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
198+
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
193199
"y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
194200
195201
julia> pointwise_loglikelihoods(model, chain, VarName)
196-
Dict{VarName,Array{Float64,2}} with 4 entries:
197-
xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
198-
y => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
202+
OrderedDict{VarName,Array{Float64,2}} with 4 entries:
199203
xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
204+
xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
200205
xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
206+
y => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
201207
```
202208
203209
## Broadcasting
@@ -224,7 +230,7 @@ julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first.((ℓ[@varname(x[1])
224230
function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T}
225231
# Get the data by executing the model once
226232
vi = VarInfo(model)
227-
context = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}())
233+
context = PointwiseLikelihoodContext(OrderedDict{T,Vector{Float64}}())
228234

229235
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
230236
for (sample_idx, chain_idx) in iters
@@ -237,15 +243,15 @@ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String)
237243

238244
niters = size(chain, 1)
239245
nchains = size(chain, 3)
240-
loglikelihoods = Dict(
246+
loglikelihoods = OrderedDict(
241247
varname => reshape(logliks, niters, nchains) for
242248
(varname, logliks) in context.loglikelihoods
243249
)
244250
return loglikelihoods
245251
end
246252

247253
function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo)
248-
context = PointwiseLikelihoodContext(Dict{VarName,Vector{Float64}}())
254+
context = PointwiseLikelihoodContext(OrderedDict{VarName,Vector{Float64}}())
249255
model(varinfo, context)
250256
return context.loglikelihoods
251257
end

0 commit comments

Comments
 (0)