Skip to content

Commit a78db51

Browse files
committed
elementwise_loglikelihoods improvement for multiple chains (#171)
In hindsight, I'm not super-happy with how we handle multiple chains. Right now, the structure will be completely flattened. This PR makes it so that the returned dict is `vn => Vector{Vector{Float64}}` rather than `Vector{Float64}`. Thoughts?
1 parent a4f6706 commit a78db51

File tree

2 files changed

+90
-25
lines changed

2 files changed

+90
-25
lines changed

src/loglikelihoods.jl

Lines changed: 81 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,42 @@ function Base.push!(
2929
ctx.loglikelihoods[vn] = logp
3030
end
3131

32+
function Base.push!(
    ctx::ElementwiseLikelihoodContext{Dict{String, Vector{Float64}}},
    vn::VarName,
    logp::Real
)
    # Vector-valued storage: accumulate one log-likelihood entry per
    # (sample, chain) model execution under the stringified variable name.
    lookup = ctx.loglikelihoods
    # Bug fix: the binding `ℓ` was missing before `get!`, leaving the
    # subsequent `push!(ℓ, logp)` referencing an undefined variable.
    ℓ = get!(lookup, string(vn), Float64[])
    push!(ℓ, logp)
end
41+
42+
function Base.push!(
    ctx::ElementwiseLikelihoodContext{Dict{String, Float64}},
    vn::VarName,
    logp::Real
)
    # Scalar storage: the variable's entry holds only the most recent
    # log-likelihood, keyed by the stringified `VarName`.
    key = string(vn)
    ctx.loglikelihoods[key] = logp
end
49+
50+
function Base.push!(
    ctx::ElementwiseLikelihoodContext{Dict{String, Vector{Float64}}},
    vn::String,
    logp::Real
)
    # Vector-valued storage with a pre-stringified key: accumulate one
    # log-likelihood entry per (sample, chain) model execution.
    lookup = ctx.loglikelihoods
    # Bug fix: the binding `ℓ` was missing before `get!`, leaving the
    # subsequent `push!(ℓ, logp)` referencing an undefined variable.
    ℓ = get!(lookup, vn, Float64[])
    push!(ℓ, logp)
end
59+
60+
function Base.push!(
    ctx::ElementwiseLikelihoodContext{Dict{String, Float64}},
    vn::String,
    logp::Real
)
    # Scalar storage with a pre-stringified key: overwrite any previous
    # log-likelihood recorded for this variable.
    lookup = ctx.loglikelihoods
    lookup[vn] = logp
end
67+
3268

3369
function tilde_assume(rng, ctx::ElementwiseLikelihoodContext, sampler, right, vn, inds, vi)
3470
return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi)
@@ -56,11 +92,14 @@ end
5692

5793

5894
"""
59-
elementwise_loglikelihoods(model::Model, chain::Chains)
95+
elementwise_loglikelihoods(model::Model, chain::Chains, keytype = String)
6096
61-
Runs `model` on each sample in `chain` returning an array of arrays with
62-
the i-th elements of the inner arrays corresponding to the likelihood of the i-th
63-
observation for that particular sample in `chain`.
97+
Runs `model` on each sample in `chain` returning a `Dict{String, Matrix{Float64}}`
98+
with keys corresponding to symbols of the observations, and values being matrices
99+
of shape `(num_chains, num_samples)`.
100+
101+
`keytype` specifies what the type of the keys used in the returned `Dict` are.
102+
Currently, only `String` and `VarName` are supported.
64103
65104
# Notes
66105
Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ`
@@ -75,10 +114,10 @@ or
75114
```julia
76115
y ~ MvNormal(fill(μ, n), fill(σ, n))
77116
```
78-
Unfortunately, just by looking at the latter statement, it's impossible to tell whether or
79-
not this is one *single* observation which is `n` dimensional OR if we have *multiple*
80-
1-dimensional observations. Therefore, `loglikelihoods` will only work with the first
81-
example.
117+
Unfortunately, just by looking at the latter statement, it's impossible to tell
118+
whether or not this is one *single* observation which is `n` dimensional OR if we
119+
have *multiple* 1-dimensional observations. Therefore, `loglikelihoods` will only
120+
work with the first example.
82121
83122
# Examples
84123
```julia-repl
@@ -99,19 +138,37 @@ julia> model = demo(randn(3), randn());
99138
100139
julia> chain = sample(model, MH(), 10);
101140
102-
julia> DynamicPPL.elementwise_loglikelihoods(model, chain)
103-
Dict{String,Array{Float64,1}} with 4 entries:
104-
"xs[3]" => [-1.02616, -1.26931, -1.05003, -5.05458, -1.33825, -1.02904, -1.23761, -1.30128, -1.04872, -2.03716]
105-
"xs[1]" => [-2.08205, -2.51387, -3.03175, -2.5981, -2.31322, -2.62284, -2.70874, -1.18617, -1.36281, -4.39839]
106-
"xs[2]" => [-2.20604, -2.63495, -3.22802, -2.48785, -2.40941, -2.78791, -2.85013, -1.24081, -1.46019, -4.59025]
107-
"y" => [-1.36627, -1.21964, -1.03342, -7.46617, -1.3234, -1.14536, -1.14781, -2.48912, -2.23705, -1.26267]
141+
julia> elementwise_loglikelihoods(model, chain)
142+
Dict{String,Array{Float64,2}} with 4 entries:
143+
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
144+
"xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
145+
"xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
146+
"y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
147+
148+
julia> elementwise_loglikelihoods(model, chain, String)
149+
Dict{String,Array{Float64,2}} with 4 entries:
150+
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
151+
"xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
152+
"xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
153+
"y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
154+
155+
julia> elementwise_loglikelihoods(model, chain, VarName)
156+
Dict{VarName,Array{Float64,2}} with 4 entries:
157+
xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
158+
y => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
159+
xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
160+
xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
108161
```
109162
"""
110-
function elementwise_loglikelihoods(model::Model, chain)
163+
function elementwise_loglikelihoods(
164+
model::Model,
165+
chain,
166+
keytype::Type{T} = String
167+
) where {T}
111168
# Get the data by executing the model once
112-
ctx = ElementwiseLikelihoodContext()
113169
spl = SampleFromPrior()
114170
vi = VarInfo(model)
171+
ctx = ElementwiseLikelihoodContext(Dict{T, Vector{Float64}}())
115172

116173
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
117174
for (sample_idx, chain_idx) in iters
@@ -121,7 +178,14 @@ function elementwise_loglikelihoods(model::Model, chain)
121178
# Execute model
122179
model(vi, spl, ctx)
123180
end
124-
return ctx.loglikelihoods
181+
182+
niters = size(chain, 1)
183+
nchains = size(chain, 3)
184+
loglikelihoods = Dict(
185+
varname => reshape(logliks, niters, nchains)
186+
for (varname, logliks) in ctx.loglikelihoods
187+
)
188+
return loglikelihoods
125189
end
126190

127191
function elementwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo)

test/loglikelihoods.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,20 @@ using .Turing
1414
xs = randn(3);
1515
y = randn();
1616
model = demo(xs, y);
17-
chain = sample(model, MH(), 100);
18-
results = elementwise_loglikelihoods(model, chain)
19-
var_to_likelihoods = Dict(string(varname) => logliks for (varname, logliks) in results)
17+
chain = sample(model, MH(), MCMCThreads(), 100, 2);
18+
var_to_likelihoods = elementwise_loglikelihoods(model, chain)
2019
@test haskey(var_to_likelihoods, "xs[1]")
2120
@test haskey(var_to_likelihoods, "xs[2]")
2221
@test haskey(var_to_likelihoods, "xs[3]")
2322
@test haskey(var_to_likelihoods, "y")
2423

25-
for (i, (s, m)) in enumerate(zip(chain[:s], chain[:m]))
26-
@test logpdf(Normal(m, s), xs[1]) == var_to_likelihoods["xs[1]"][i]
27-
@test logpdf(Normal(m, s), xs[2]) == var_to_likelihoods["xs[2]"][i]
28-
@test logpdf(Normal(m, s), xs[3]) == var_to_likelihoods["xs[3]"][i]
29-
@test logpdf(Normal(m, s), y) == var_to_likelihoods["y"][i]
24+
for chain_idx in MCMCChains.chains(chain)
25+
for (i, (s, m)) in enumerate(zip(chain[:, :s, chain_idx], chain[:, :m, chain_idx]))
26+
@test logpdf(Normal(m, s), xs[1]) == var_to_likelihoods["xs[1]"][i, chain_idx]
27+
@test logpdf(Normal(m, s), xs[2]) == var_to_likelihoods["xs[2]"][i, chain_idx]
28+
@test logpdf(Normal(m, s), xs[3]) == var_to_likelihoods["xs[3]"][i, chain_idx]
29+
@test logpdf(Normal(m, s), y) == var_to_likelihoods["y"][i, chain_idx]
30+
end
3031
end
3132

3233
var_info = VarInfo(model)

0 commit comments

Comments
 (0)