Skip to content

Commit 14f4788

Browse files
committed
Add @addlogprior! and @addloglikelihood!
1 parent 00ef0cf commit 14f4788

File tree

5 files changed

+181
-13
lines changed

5 files changed

+181
-13
lines changed

docs/src/api.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,12 @@ returned(::Model)
160160

161161
## Utilities
162162

163-
It is possible to manually increase (or decrease) the accumulated log likelihood from within a model function.
163+
It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function.
164164

165165
```@docs
166166
@addlogprob!
167+
@addloglikelihood!
168+
@addlogprior!
167169
```
168170

169171
Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref).

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ export AbstractVarInfo,
127127
to_submodel,
128128
# Convenience macros
129129
@addlogprob!,
130+
@addlogprior!,
131+
@addloglikelihood!,
130132
@submodel,
131133
value_iterator_from_chain,
132134
check_model,

src/abstract_varinfo.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -320,13 +320,18 @@ function accloglikelihood!!(vi::AbstractVarInfo, logp)
320320
end
321321

322322
"""
323-
acclogp!!(vi::AbstractVarInfo, logp::NamedTuple)
323+
acclogp!!(vi::AbstractVarInfo, logp::NamedTuple; ignore_missing_accumulator::Bool=false)
324324
325325
Add to both the log prior and the log likelihood probabilities in `vi`.
326326
327327
`logp` should have fields `logprior` and/or `loglikelihood`, and no other fields.
328+
329+
By default if the necessary accumulators are not in `vi` an error is thrown. If
330+
`ignore_missing_accumulator` is set to `true` then this is silently ignored instead.
328331
"""
329-
function acclogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names}
332+
function acclogp!!(
333+
vi::AbstractVarInfo, logp::NamedTuple{names}; ignore_missing_accumulator=false
334+
) where {names}
330335
if !(
331336
names == (:logprior, :loglikelihood) ||
332337
names == (:loglikelihood, :logprior) ||
@@ -335,17 +340,19 @@ function acclogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names}
335340
)
336341
error("logp must have fields logprior and/or loglikelihood and no other fields.")
337342
end
338-
if haskey(logp, :logprior)
343+
if haskey(logp, :logprior) &&
344+
(!ignore_missing_accumulator || hasacc(vi, Val(:LogPrior)))
339345
vi = acclogprior!!(vi, logp.logprior)
340346
end
341-
if haskey(logp, :loglikelihood)
347+
if haskey(logp, :loglikelihood) &&
348+
(!ignore_missing_accumulator || hasacc(vi, Val(:LogLikelihood)))
342349
vi = accloglikelihood!!(vi, logp.loglikelihood)
343350
end
344351
return vi
345352
end
346353

347354
function acclogp!!(vi::AbstractVarInfo, logp::Number)
348-
depwarn(
355+
Base.depwarn(
349356
"`acclogp!!(vi::AbstractVarInfo, logp::Number)` is deprecated. Use `accloglikelihood!!(vi, logp)` instead.",
350357
:acclogp,
351358
)

src/utils.jl

Lines changed: 116 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,86 @@ const LogProbType = float(Real)
1818
"""
1919
@addlogprob!(ex)
2020
21+
Add a term to the log joint.
22+
23+
If `ex` evaluates to a `NamedTuple` with keys `:loglikelihood` and/or `:logprior`, the
24+
values are added to the log likelihood and log prior respectively.
25+
26+
If `ex` evaluates to a number it is added to the log likelihood. This use is deprecated
27+
and should be replaced with either the `NamedTuple` version or calls to
28+
[`@addloglikelihood!`](@ref).
29+
30+
See also [`@addloglikelihood!`](@ref), [`@addlogprior!`](@ref).
31+
32+
# Examples
33+
34+
```jldoctest; setup = :(using Distributions)
35+
julia> mylogjoint(x, μ) = (; loglikelihood=loglikelihood(Normal(μ, 1), x), logprior=1.0);
36+
37+
julia> @model function demo(x)
38+
μ ~ Normal()
39+
@addlogprob! mylogjoint(x, μ)
40+
end;
41+
42+
julia> x = [1.3, -2.1];
43+
44+
julia> loglikelihood(demo(x), (μ=0.2,)) ≈ mylogjoint(x, 0.2).loglikelihood
45+
true
46+
47+
julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2) + mylogjoint(x, 0.2).logprior
48+
true
49+
```
50+
51+
and to [reject samples](https://github.com/TuringLang/Turing.jl/issues/1328):
52+
53+
```jldoctest; setup = :(using Distributions, LinearAlgebra)
54+
julia> @model function demo(x)
55+
m ~ MvNormal(zero(x), I)
56+
if dot(m, x) < 0
57+
@addlogprob! (; loglikelihood=-Inf)
58+
# Exit the model evaluation early
59+
return
60+
end
61+
x ~ MvNormal(m, I)
62+
return
63+
end;
64+
65+
julia> logjoint(demo([-2.1]), (m=[0.2],)) == -Inf
66+
true
67+
```
68+
"""
69+
macro addlogprob!(ex)
70+
return quote
71+
val = $(esc(ex))
72+
vi = $(esc(:(__varinfo__)))
73+
if val isa Number
74+
Base.depwarn(
75+
"""
76+
@addlogprob! with a single number argument is deprecated. Please use
77+
@addlogprob! (; loglikelihood=x) or @addloglikelihood! instead.
78+
""",
79+
:addlogprob!,
80+
)
81+
if hasacc(vi, Val(:LogLikelihood))
82+
$(esc(:(__varinfo__))) = accloglikelihood!!($(esc(:(__varinfo__))), val)
83+
end
84+
elseif !isa(val, NamedTuple)
85+
error("logp must be a NamedTuple.")
86+
else
87+
$(esc(:(__varinfo__))) = acclogp!!(
88+
$(esc(:(__varinfo__))), val; ignore_missing_accumulator=true
89+
)
90+
end
91+
end
92+
end
93+
94+
"""
95+
@addloglikelihood!(ex)
96+
2197
Add the result of the evaluation of `ex` to the log likelihood.
2298
99+
See also [`@addlogprob!`](@ref), [`@addlogprior!`](@ref).
100+
23101
# Examples
24102
25103
This macro allows you to [include arbitrary terms in the likelihood](https://github.com/TuringLang/Turing.jl/issues/1332)
@@ -29,7 +107,7 @@ julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x);
29107
30108
julia> @model function demo(x)
31109
μ ~ Normal()
32-
@addlogprob! myloglikelihood(x, μ)
110+
@addloglikelihood! myloglikelihood(x, μ)
33111
end;
34112
35113
julia> x = [1.3, -2.1];
@@ -44,7 +122,7 @@ and to [reject samples](https://github.com/TuringLang/Turing.jl/issues/1328):
44122
julia> @model function demo(x)
45123
m ~ MvNormal(zero(x), I)
46124
if dot(m, x) < 0
47-
@addlogprob! -Inf
125+
@addloglikelihood! -Inf
48126
# Exit the model evaluation early
49127
return
50128
end
@@ -56,10 +134,43 @@ julia> logjoint(demo([-2.1]), (m=[0.2],)) == -Inf
56134
true
57135
```
58136
"""
59-
macro addlogprob!(ex)
137+
macro addloglikelihood!(ex)
138+
return quote
139+
if hasacc($(esc(:(__varinfo__))), Val(:LogLikelihood))
140+
$(esc(:(__varinfo__))) = accloglikelihood!!($(esc(:(__varinfo__))), $(esc(ex)))
141+
end
142+
end
143+
end
144+
145+
"""
146+
@addlogprior!(ex)
147+
148+
Add the result of the evaluation of `ex` to the log prior.
149+
150+
See also [`@addloglikelihood!`](@ref), [`@addlogprob!`](@ref).
151+
152+
# Examples
153+
154+
This macro allows you to include arbitrary terms in the prior.
155+
156+
```jldoctest; setup = :(using Distributions)
157+
julia> mylogpriorextraterm(μ) = μ > 0 ? -1.0 : 0.0;
158+
159+
julia> @model function demo(x)
160+
μ ~ Normal()
161+
@addlogprior! mylogpriorextraterm(μ)
162+
end;
163+
164+
julia> x = [1.3, -2.1];
165+
166+
julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2) + mylogpriorextraterm(0.2)
167+
true
168+
```
169+
"""
170+
macro addlogprior!(ex)
60171
return quote
61-
if $hasacc($(esc(:(__varinfo__))), Val(:LogLikelihood))
62-
$(esc(:(__varinfo__))) = $accloglikelihood!!($(esc(:(__varinfo__))), $(esc(ex)))
172+
if hasacc($(esc(:(__varinfo__))), Val(:LogPrior))
173+
$(esc(:(__varinfo__))) = acclogprior!!($(esc(:(__varinfo__))), $(esc(ex)))
63174
end
64175
end
65176
end

test/utils.jl

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,56 @@
66
return global lp_after = getlogjoint(__varinfo__)
77
end
88

9-
model = testmodel()
10-
varinfo = VarInfo(model)
9+
varinfo = VarInfo(testmodel())
1110
@test iszero(lp_before)
1211
@test getlogjoint(varinfo) == lp_after == 42
12+
@test getloglikelihood(varinfo) == 42
13+
14+
@model function testmodel_nt()
15+
global lp_before = getlogjoint(__varinfo__)
16+
@addlogprob! (; logprior=(pi + 1), loglikelihood=42)
17+
return global lp_after = getlogjoint(__varinfo__)
18+
end
19+
20+
varinfo = VarInfo(testmodel_nt())
21+
@test iszero(lp_before)
22+
@test getlogjoint(varinfo) == lp_after == 42 + 1 + pi
23+
@test getloglikelihood(varinfo) == 42
24+
@test getlogprior(varinfo) == pi + 1
25+
26+
@model function testmodel_nt2()
27+
global lp_before = getlogjoint(__varinfo__)
28+
llh_nt = (; loglikelihood=42)
29+
@addlogprob! llh_nt
30+
return global lp_after = getlogjoint(__varinfo__)
31+
end
32+
33+
varinfo = VarInfo(testmodel_nt2())
34+
@test iszero(lp_before)
35+
@test getlogjoint(varinfo) == lp_after == 42
36+
@test getloglikelihood(varinfo) == 42
37+
38+
@model function testmodel_likelihood()
39+
global lp_before = getlogjoint(__varinfo__)
40+
@addloglikelihood! 42
41+
return global lp_after = getlogjoint(__varinfo__)
42+
end
43+
44+
varinfo = VarInfo(testmodel_likelihood())
45+
@test iszero(lp_before)
46+
@test getlogjoint(varinfo) == lp_after == 42
47+
@test getloglikelihood(varinfo) == 42
48+
49+
@model function testmodel_prior()
50+
global lp_before = getlogjoint(__varinfo__)
51+
@addlogprior! 42
52+
return global lp_after = getlogjoint(__varinfo__)
53+
end
54+
55+
varinfo = VarInfo(testmodel_prior())
56+
@test iszero(lp_before)
57+
@test getlogjoint(varinfo) == lp_after == 42
58+
@test getlogprior(varinfo) == 42
1359
end
1460

1561
@testset "getargs_dottilde" begin

0 commit comments

Comments
 (0)