
Commit 64357e1

Fix for pointwise_loglikelihoods (#281)

Currently, statements such as `x .~ Normal()` result in only a single entry in the result of `pointwise_loglikelihoods`, i.e. `x` is treated as a single multivariate random variable rather than as a collection of independent random variables. This is unfortunate for a couple of reasons:

a) It is counter-intuitive, as indicated by users finding it confusing: TuringLang/Turing.jl#1666. I fully agree with them, in particular because of (b).

b) It differs from how `x` is treated in `dot_tilde_assume`, due to the use of `DynamicPPL.unwrap_right_left_vns` in the assume-branch but _not_ in the observe-branch: https://github.com/TuringLang/DynamicPPL.jl/blob/b82459a081c4b8925da3c0d97a6dc61687648ed3/src/compiler.jl#L369-L387

We _could_ simply add `unwrap_right_left_vns` to the observe-branch too, _but_ that would add some unnecessary overhead due to https://github.com/TuringLang/DynamicPPL.jl/blob/b82459a081c4b8925da3c0d97a6dc61687648ed3/src/compiler.jl#L106-L115. On the bright side, it would make the inputs to `dot_tilde_assume!` and `dot_tilde_observe!` more consistent, so I'm a bit uncertain what the "right" choice is here.

For now I've decided to simply call `unwrap_right_left_vns` from within `dot_tilde_observe!` for `PointwiseLikelihoodContext`, as it only adds overhead to the `pointwise_loglikelihoods` computation and nothing else. IMO this is the way to go for this PR, but the above should be given more thought later, e.g. by introducing a multi-index `VarName`.

1 parent 9d4a8f2 · commit 64357e1
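To make the change concrete: the following is a minimal Python sketch (illustrative only, not DynamicPPL code; `normal_logpdf` is a made-up helper) contrasting the old bookkeeping, one joint entry for `x`, with the new one, one entry per element:

```python
import math

def normal_logpdf(x, mu=0.0, sigma=1.0):
    # log density of a univariate Normal(mu, sigma)
    return -0.5 * math.log(2 * math.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

x = [1.0, 1.0]

# Old behaviour: `x .~ Normal()` produced a SINGLE entry, as if `x` were
# one multivariate observation.
joint = {"x": sum(normal_logpdf(xi) for xi in x)}

# New behaviour: one entry PER element, treating the elements of `x` as
# independent observations.
pointwise = {f"x[{i + 1}]": normal_logpdf(xi) for i, xi in enumerate(x)}

print(joint)      # {'x': -2.8378770664093453}
print(pointwise)  # one value of about -1.41894 per element
```

The per-element value matches the `-1.4189385332046727` that appears in the updated doctest below.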

File tree

3 files changed: +94 −27 lines changed


Project.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.12.3"
+version = "0.12.4"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
```

src/loglikelihoods.jl

Lines changed: 60 additions & 11 deletions

````diff
@@ -92,15 +92,35 @@ end
 function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi)
     # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e.
     # we have to intercept the call to `dot_tilde_observe!`.
-    logp = dot_tilde_observe(context.context, right, left, vi)
-    acclogp!(vi, logp)
 
-    # Track loglikelihood value.
-    push!(context, vn, logp)
+    # We want to treat `.~` as a collection of independent observations,
+    # hence we need the `logp` for each of them. Broadcasting the univariate
+    # `tilde_observe` does exactly this.
+    logps = _pointwise_tilde_observe(context.context, right, left, vi)
+    acclogp!(vi, sum(logps))
+
+    # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`.
+    _, _, vns = unwrap_right_left_vns(right, left, vn)
+    for (vn, logp) in zip(vns, logps)
+        # Track loglikelihood value.
+        push!(context, vn, logp)
+    end
 
     return left
 end
 
+# FIXME: This is really not a good approach since it needs to stay in sync with
+# the `dot_assume` implementations, but as things are _right now_ this is the best we can do.
+function _pointwise_tilde_observe(context, right, left, vi)
+    return tilde_observe.(Ref(context), right, left, Ref(vi))
+end
+
+function _pointwise_tilde_observe(
+    context, right::MultivariateDistribution, left::AbstractMatrix, vi
+)
+    return tilde_observe.(Ref(context), Ref(right), eachcol(left), Ref(vi))
+end
+
 """
     pointwise_loglikelihoods(model::Model, chain::Chains, keytype = String)
 
@@ -114,22 +134,30 @@ Currently, only `String` and `VarName` are supported.
 # Notes
 Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ`
 both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an
-*observation*) statements can be implemented in two ways:
+*observation*) statements can be implemented in three ways:
+1. using a `for` loop:
 ```julia
 for i in eachindex(y)
     y[i] ~ Normal(μ, σ)
 end
 ```
-or
+2. using `.~`:
+```julia
+y .~ Normal(μ, σ)
+```
+3. using `MvNormal`:
 ```julia
-y ~ MvNormal(fill(μ, n), fill(σ, n))
+y ~ MvNormal(fill(μ, n), Diagonal(fill(σ, n)))
 ```
-Unfortunately, just by looking at the latter statement, it's impossible to tell
-whether or not this is one *single* observation which is `n` dimensional OR if we
-have *multiple* 1-dimensional observations. Therefore, `loglikelihoods` will only
-work with the first example.
+
+In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables,
+while in (3) `y` will be treated as a _single_ n-dimensional observation.
+
+This is important to keep in mind, in particular if the computation is used
+for downstream computations.
 
 # Examples
+## From chain
 ```julia-repl
 julia> using DynamicPPL, Turing
@@ -169,6 +197,27 @@ Dict{VarName,Array{Float64,2}} with 4 entries:
   xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
   xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
 ```
+
+## Broadcasting
+Note that `x .~ Dist()` will treat `x` as a collection of
+_independent_ observations rather than as a single observation.
+
+```jldoctest; setup = :(using Distributions)
+julia> @model function demo(x)
+           x .~ Normal()
+       end;
+
+julia> m = demo([1.0, ]);
+
+julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first(ℓ[@varname(x[1])])
+-1.4189385332046727
+
+julia> m = demo([1.0; 1.0]);
+
+julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])]))
+(-1.4189385332046727, -1.4189385332046727)
+```
+
 """
 function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T}
     # Get the data by executing the model once
````

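As a numeric sanity check on the docstring's three formulations, here is a standalone Python sketch (an illustration with a hypothetical `normal_logpdf` helper, not the package's code) showing that the per-element terms of (1) and (2) sum to the single joint term of (3):

```python
import math

def normal_logpdf(x, mu=0.0, sigma=1.0):
    # log density of a univariate Normal(mu, sigma)
    return -0.5 * math.log(2 * math.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

y, mu, sigma = [10.0, 10.0], 1.0, 0.5

# (1) for-loop: one term per element.
loop_total = sum(normal_logpdf(yi, mu, sigma) for yi in y)

# (2) broadcasting `.~`: the same per-element terms, kept separately.
broadcast_terms = [normal_logpdf(yi, mu, sigma) for yi in y]

# (3) diagonal-covariance MvNormal: a SINGLE joint term, which for a
# diagonal covariance factorises into the same sum.
mv_total = sum(
    normal_logpdf(yi, m, s) for yi, m, s in zip(y, [mu] * len(y), [sigma] * len(y))
)

# The totals agree; only the bookkeeping (n entries vs one entry) differs.
```

This is exactly why `pointwise_loglikelihoods` can report per-element values for (1) and (2) but only a single value for (3).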
test/loglikelihoods.jl

Lines changed: 33 additions & 15 deletions

```diff
@@ -4,7 +4,7 @@
     # `dot_assume` and `observe`
     m = TV(undef, length(x))
     m .~ Normal()
-    return x ~ MvNormal(m, 0.5 * ones(length(x)))
+    return x ~ MvNormal(m, 0.5)
 end
 
 @model function gdemo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV}
@@ -13,13 +13,13 @@ end
     for i in eachindex(m)
         m[i] ~ Normal()
     end
-    return x ~ MvNormal(m, 0.5 * ones(length(x)))
+    return x ~ MvNormal(m, 0.5)
 end
 
 @model function gdemo3(x=10 * ones(2))
     # Multivariate `assume` and `observe`
     m ~ MvNormal(length(x), 1.0)
-    return x ~ MvNormal(m, 0.5 * ones(length(x)))
+    return x ~ MvNormal(m, 0.5)
 end
 
 @model function gdemo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV}
@@ -39,11 +39,11 @@ end
     return x .~ Normal(m, 0.5)
 end
 
-# @model function gdemo6(::Type{TV} = Vector{Float64}) where {TV}
-#     # `assume` and literal `observe`
-#     m ~ MvNormal(length(x), 1.0)
-#     [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2))
-# end
+@model function gdemo6(::Type{TV}=Vector{Float64}) where {TV}
+    # `assume` and literal `observe`
+    m ~ MvNormal(2, 1.0)
+    return [10.0, 10.0] ~ MvNormal(m, 0.5)
+end
 
 @model function gdemo7(::Type{TV}=Vector{Float64}) where {TV}
     # `dot_assume` and literal `observe` with indexing
@@ -54,11 +54,11 @@ end
     end
 end
 
-# @model function gdemo8(::Type{TV} = Vector{Float64}) where {TV}
-#     # `assume` and literal `dot_observe`
-#     m ~ Normal()
-#     [10.0, ] .~ Normal(m, 0.5)
-# end
+@model function gdemo8(::Type{TV}=Vector{Float64}) where {TV}
+    # `assume` and literal `dot_observe`
+    m ~ Normal()
+    return [10.0] .~ Normal(m, 0.5)
+end
 
 @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
     m = TV(undef, 2)
@@ -76,7 +76,7 @@ end
 end
 
 @model function _likelihood_dot_observe(m, x)
-    return x ~ MvNormal(m, 0.5 * ones(length(m)))
+    return x ~ MvNormal(m, 0.5)
 end
 
 @model function gdemo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV}
@@ -87,8 +87,26 @@ end
     @submodel _likelihood_dot_observe(m, x)
 end
 
+@model function gdemo11(x=10 * ones(2, 1), ::Type{TV}=Vector{Float64}) where {TV}
+    m = TV(undef, length(x))
+    m .~ Normal()
+
+    # Dotted observe for `Matrix`.
+    return x .~ MvNormal(m, 0.5)
+end
+
 const gdemo_models = (
-    gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10()
+    gdemo1(),
+    gdemo2(),
+    gdemo3(),
+    gdemo4(),
+    gdemo5(),
+    gdemo6(),
+    gdemo7(),
+    gdemo8(),
+    gdemo9(),
+    gdemo10(),
+    gdemo11(),
 )
 
 @testset "loglikelihoods.jl" begin
```

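The new `gdemo11` exercises the `AbstractMatrix` method of `_pointwise_tilde_observe`, which treats each column of the matrix as one draw from the multivariate distribution (mirroring `eachcol(left)`). A rough Python sketch of that column-wise accounting (illustrative only; `mvnormal_iso_logpdf` is a made-up helper for a diagonal-covariance Normal):

```python
import math

def normal_logpdf(x, mu=0.0, sigma=1.0):
    # log density of a univariate Normal(mu, sigma)
    return -0.5 * math.log(2 * math.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

def mvnormal_iso_logpdf(col, mu, sigma):
    # a diagonal-covariance MvNormal factorises into a sum of univariate terms
    return sum(normal_logpdf(x, m, sigma) for x, m in zip(col, mu))

# `left` plays the role of the observed AbstractMatrix: each COLUMN is one
# observation of the multivariate distribution.
left = [[1.0, 2.0],   # row 1
        [1.0, 2.0]]   # row 2  -> two columns = two observations
mu = [0.0, 0.0]

# One log-likelihood per column; their sum is what `acclogp!` receives.
logps = [mvnormal_iso_logpdf(col, mu, 1.0) for col in zip(*left)]
total = sum(logps)
```

So a `2×1` matrix (as in `gdemo11`) yields a single pointwise entry, while a `2×n` matrix yields `n` of them.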