
Commit 4b6d95d

Method to extract loglikelihoods (#166)
For several reasons, it would be very nice to have a way of extracting the log-likelihoods from a chain. This PR implements the method `loglikelihoods` to do exactly this.

# Up for discussion

1. **Return value.** Right now it returns a `Dict{String, Vector{Float64}}`, with the keys being `string(varname)` and the values being arrays whose i-th element is the log-likelihood of `string(varname)` in `chain[i]`. Alternatives:
   - a "hierarchical" dict of the form `Dict(y => Dict(y[1] => ..., y[2] => ...), ...)`
   - a "flattened" dict of the form `Dict(y[1] => ..., y[2] => ..., ...)`
   - ????
2. **Project structure.** I'm a bit uncertain where to actually put the implementation. As I have now experienced, exactly what you need to implement to make an `AbstractSampler` is a bit unclear, e.g. there are methods in `varinfo.jl` (such as `getindex`) which also require implementations. So, should I keep it in its own file, as I have now, or should I follow suit with `SampleFromPrior` and `SampleFromUniform`?

# Example

```julia
julia> using DynamicPPL, Turing

julia> @model function demo(xs, y)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, √s)
           for i in eachindex(xs)
               xs[i] ~ Normal(m, √s)
           end
           y ~ Normal(m, √s)
       end
demo (generic function with 1 method)

julia> model = demo(randn(3), randn());

julia> chain = sample(model, MH(), 10);

julia> DynamicPPL.loglikelihoods(model, chain)
Dict{String,Array{Float64,1}} with 4 entries:
  "xs[3]" => [-1.02616, -1.26931, -1.05003, -5.05458, -1.33825, -1.02904, -1.23761, -1.30128, -1.04872, -2.03716]
  "xs[1]" => [-2.08205, -2.51387, -3.03175, -2.5981, -2.31322, -2.62284, -2.70874, -1.18617, -1.36281, -4.39839]
  "xs[2]" => [-2.20604, -2.63495, -3.22802, -2.48785, -2.40941, -2.78791, -2.85013, -1.24081, -1.46019, -4.59025]
  "y"     => [-1.36627, -1.21964, -1.03342, -7.46617, -1.3234, -1.14536, -1.14781, -2.48912, -2.23705, -1.26267]
```
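Not part of the commit, but relevant to the return-value question above: a minimal sketch of how a consumer could flatten the string-keyed `Dict{String, Vector{Float64}}` shown in the example into a samples × observations matrix. The helper name `loglik_matrix` is purely illustrative.

```julia
# Hypothetical helper, not part of this PR: reshape the per-variable log-likelihood
# vectors into a samples × observations matrix with a fixed column order.
function loglik_matrix(loglikelihoods::Dict{String, Vector{Float64}})
    names = sort(collect(keys(loglikelihoods)))  # fix a column order
    return names, reduce(hcat, (loglikelihoods[n] for n in names))
end

# Usage: names, L = loglik_matrix(DynamicPPL.loglikelihoods(model, chain))
# L[i, j] is then the log-likelihood of observation names[j] in sample i.
```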
1 parent 405546f commit 4b6d95d

File tree

5 files changed: +175 −1 lines changed


Project.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.9.1"
+version = "0.9.2"

 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
```

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
```diff
@@ -90,6 +90,7 @@ export AbstractVarInfo,
     # Convenience functions
     logprior,
     logjoint,
+    elementwise_loglikelihoods,
     # Convenience macros
     @addlogprob!

@@ -118,5 +119,6 @@ include("context_implementations.jl")
 include("compiler.jl")
 include("prob_macro.jl")
 include("compat/ad.jl")
+include("loglikelihoods.jl")

 end # module
```

src/loglikelihoods.jl

Lines changed: 131 additions & 0 deletions
````julia
# Context version
struct ElementwiseLikelihoodContext{A, Ctx} <: AbstractContext
    loglikelihoods::A
    ctx::Ctx
end

function ElementwiseLikelihoodContext(
    likelihoods = Dict{VarName, Vector{Float64}}(),
    ctx::AbstractContext = LikelihoodContext()
)
    return ElementwiseLikelihoodContext{typeof(likelihoods), typeof(ctx)}(likelihoods, ctx)
end

function Base.push!(
    ctx::ElementwiseLikelihoodContext{Dict{VarName, Vector{Float64}}},
    vn::VarName,
    logp::Real
)
    lookup = ctx.loglikelihoods
    ℓ = get!(lookup, vn, Float64[])
    push!(ℓ, logp)
end

function Base.push!(
    ctx::ElementwiseLikelihoodContext{Dict{VarName, Float64}},
    vn::VarName,
    logp::Real
)
    ctx.loglikelihoods[vn] = logp
end


function tilde_assume(rng, ctx::ElementwiseLikelihoodContext, sampler, right, vn, inds, vi)
    return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi)
end

function dot_tilde_assume(rng, ctx::ElementwiseLikelihoodContext, sampler, right, left, vn, inds, vi)
    value, logp = dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi)
    acclogp!(vi, logp)
    return value
end


function tilde_observe(ctx::ElementwiseLikelihoodContext, sampler, right, left, vname, vinds, vi)
    # This is slightly unfortunate since it is not completely generic...
    # Ideally we would call `tilde_observe` recursively, but then we wouldn't get the
    # log-likelihood value.
    logp = tilde(ctx.ctx, sampler, right, left, vi)
    acclogp!(vi, logp)

    # Track the log-likelihood value.
    push!(ctx, vname, logp)

    return left
end


"""
    elementwise_loglikelihoods(model::Model, chain::Chains)

Run `model` on each sample in `chain` and return a `Dict` mapping each observed
variable to a vector whose i-th element is the log-likelihood of that observation
for the i-th sample in `chain`.

# Notes
Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ`
both being `<:Real`. Then the *observe* statements (i.e. those where the left-hand
side is an *observation*) can be implemented in two ways:
```julia
for i in eachindex(y)
    y[i] ~ Normal(μ, σ)
end
```
or
```julia
y ~ MvNormal(fill(μ, n), fill(σ, n))
```
Unfortunately, just by looking at the latter statement, it's impossible to tell whether
it is one *single* `n`-dimensional observation or *multiple* 1-dimensional observations.
Therefore, `elementwise_loglikelihoods` only works with the first form.

# Examples
```julia-repl
julia> using DynamicPPL, Turing

julia> @model function demo(xs, y)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, √s)
           for i in eachindex(xs)
               xs[i] ~ Normal(m, √s)
           end

           y ~ Normal(m, √s)
       end
demo (generic function with 1 method)

julia> model = demo(randn(3), randn());

julia> chain = sample(model, MH(), 10);

julia> DynamicPPL.elementwise_loglikelihoods(model, chain)
Dict{String,Array{Float64,1}} with 4 entries:
  "xs[3]" => [-1.02616, -1.26931, -1.05003, -5.05458, -1.33825, -1.02904, -1.23761, -1.30128, -1.04872, -2.03716]
  "xs[1]" => [-2.08205, -2.51387, -3.03175, -2.5981, -2.31322, -2.62284, -2.70874, -1.18617, -1.36281, -4.39839]
  "xs[2]" => [-2.20604, -2.63495, -3.22802, -2.48785, -2.40941, -2.78791, -2.85013, -1.24081, -1.46019, -4.59025]
  "y"     => [-1.36627, -1.21964, -1.03342, -7.46617, -1.3234, -1.14536, -1.14781, -2.48912, -2.23705, -1.26267]
```
"""
function elementwise_loglikelihoods(model::Model, chain)
    # Set up the per-variable accumulator and a varinfo by executing the model once.
    ctx = ElementwiseLikelihoodContext()
    spl = SampleFromPrior()
    vi = VarInfo(model)

    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
    for (sample_idx, chain_idx) in iters
        # Update the values in the varinfo from the chain sample.
        setval!(vi, chain, sample_idx, chain_idx)

        # Execute the model, accumulating per-observation log-likelihoods in `ctx`.
        model(vi, spl, ctx)
    end
    return ctx.loglikelihoods
end

function elementwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo)
    ctx = ElementwiseLikelihoodContext(Dict{VarName, Float64}())
    model(varinfo, SampleFromPrior(), ctx)
    return ctx.loglikelihoods
end
````
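As a consumer-side sanity check (not part of the commit): since each observe statement's log-likelihood is recorded per sample, summing the returned vectors across variables gives the total observation log-likelihood for each sample. A minimal sketch, assuming the chain-based method's `Dict{VarName, Vector{Float64}}` return shape:

```julia
# Sketch only (not part of this commit): elementwise sum of the per-variable vectors
# yields one total observation log-likelihood per sample in the chain.
total_loglikelihood_per_sample(loglikelihoods::AbstractDict) =
    reduce(+, values(loglikelihoods))
```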

test/loglikelihoods.jl

Lines changed: 40 additions & 0 deletions
```julia
using .Turing

@testset "loglikelihoods" begin
    @model function demo(xs, y)
        s ~ InverseGamma(2, 3)
        m ~ Normal(0, s)
        for i in eachindex(xs)
            xs[i] ~ Normal(m, s)
        end

        y ~ Normal(m, s)
    end

    xs = randn(3);
    y = randn();
    model = demo(xs, y);
    chain = sample(model, MH(), 100);
    results = elementwise_loglikelihoods(model, chain)
    var_to_likelihoods = Dict(string(varname) => logliks for (varname, logliks) in results)
    @test haskey(var_to_likelihoods, "xs[1]")
    @test haskey(var_to_likelihoods, "xs[2]")
    @test haskey(var_to_likelihoods, "xs[3]")
    @test haskey(var_to_likelihoods, "y")

    for (i, (s, m)) in enumerate(zip(chain[:s], chain[:m]))
        @test logpdf(Normal(m, s), xs[1]) == var_to_likelihoods["xs[1]"][i]
        @test logpdf(Normal(m, s), xs[2]) == var_to_likelihoods["xs[2]"][i]
        @test logpdf(Normal(m, s), xs[3]) == var_to_likelihoods["xs[3]"][i]
        @test logpdf(Normal(m, s), y) == var_to_likelihoods["y"][i]
    end

    var_info = VarInfo(model)
    results = DynamicPPL.elementwise_loglikelihoods(model, var_info)
    var_to_likelihoods = Dict(string(vn) => ℓ for (vn, ℓ) in results)
    s, m = var_info[SampleFromPrior()]
    @test logpdf(Normal(m, s), xs[1]) == var_to_likelihoods["xs[1]"]
    @test logpdf(Normal(m, s), xs[2]) == var_to_likelihoods["xs[2]"]
    @test logpdf(Normal(m, s), xs[3]) == var_to_likelihoods["xs[3]"]
    @test logpdf(Normal(m, s), y) == var_to_likelihoods["y"]
end
```

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ include("test_util.jl")
2828
include("independence.jl")
2929
include("distribution_wrappers.jl")
3030
include("context_implementations.jl")
31+
include("loglikelihoods.jl")
3132

3233
include("threadsafe.jl")
3334
