Skip to content

Commit 7b7a3e2

Browse files
committed
Add NamedTuple methods for get/set/acclogp
1 parent 8241d12 commit 7b7a3e2

File tree

5 files changed

+108
-29
lines changed

5 files changed

+108
-29
lines changed

HISTORY.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ This release overhauls how VarInfo objects track variables such as the log joint
1515
- For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. Rather we call `tilde_observe!!` with `vn` set to `nothing`.
1616
- `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value.
1717
- `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`.
18+
- `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`.
19+
- Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The method with a single scalar value has been deprecated, and falls back on `setloglikelihood!!` or `accloglikelihood!!`. Corresponding setter/accumulator functions exist for the log prior as well.
1820

1921
## 0.36.0
2022

docs/src/api.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,12 +366,15 @@ NumProduce
366366
#### Accumulation of log-probabilities
367367

368368
```@docs
369-
getlogprior
370-
getloglikelihood
369+
getlogp
370+
setlogp!!
371+
acclogp!!
371372
getlogjoint
373+
getlogprior
372374
setlogprior!!
373-
setloglikelihood!!
374375
acclogprior!!
376+
getloglikelihood
377+
setloglikelihood!!
375378
accloglikelihood!!
376379
resetlogp!!
377380
```

src/abstract_varinfo.jl

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,17 @@ Return the log of the joint probability of the observed data and parameters in `
9898
See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref).
9999
"""
100100
getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi)
101+
102+
"""
103+
getlogp(vi::AbstractVarInfo)
104+
105+
Return a NamedTuple of the log prior and log likelihood probabilities.
106+
107+
The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an
108+
error will be thrown.
109+
"""
101110
function getlogp(vi::AbstractVarInfo)
102-
Base.depwarn("getlogp is deprecated, use getlogjoint instead", :getlogp)
103-
return getlogjoint(vi)
111+
return (; logprior=getlogprior(vi), loglikelihood=getloglikelihood(vi))
104112
end
105113

106114
"""
@@ -198,23 +206,31 @@ See also: [`setlogprior!!`](@ref), [`setlogp!!`](@ref), [`getloglikelihood`](@re
198206
setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihood(logp))
199207

200208
"""
201-
setlogp!!(vi::AbstractVarInfo, logp)
209+
setlogp!!(vi::AbstractVarInfo, logp::NamedTuple)
202210
203-
Set the log of the joint probability of the observed data and parameters sampled in
204-
`vi` to `logp`, mutating if it makes sense.
211+
Set both the log prior and the log likelihood probabilities in `vi`.
212+
213+
`logp` should have fields `logprior` and `loglikelihood` and no other fields.
205214
206215
See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref).
207216
"""
208-
function setlogp!!(vi::AbstractVarInfo, logp)
209-
Base.depwarn(
210-
"setlogp!! is deprecated, use setlogprior!! or setloglikelihood!! instead",
211-
:setlogp!!,
212-
)
213-
vi = setlogprior!!(vi, zero(logp))
214-
vi = setloglikelihood!!(vi, logp)
217+
function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names}
218+
if !(names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior))
219+
error("logp must have the fields logprior and loglikelihood and no other fields.")
220+
end
221+
vi = setlogprior!!(vi, logp.logprior)
222+
vi = setloglikelihood!!(vi, logp.loglikelihood)
215223
return vi
216224
end
217225

226+
function setlogp!!(vi::AbstractVarInfo, logp::Number)
227+
depwarn(
228+
"`setlogp!!(vi::AbstractVarInfo, logp::Number)` is deprecated. Use `setloglikelihood!!(vi, logp)` instead.",
229+
:setlogp,
230+
)
231+
return setloglikelihood!!(vi, logp)
232+
end
233+
218234
"""
219235
getacc(vi::AbstractVarInfo, ::Val{accname})
220236
@@ -303,15 +319,34 @@ function accloglikelihood!!(vi::AbstractVarInfo, logp)
303319
end
304320

305321
"""
306-
acclogp!!(vi::AbstractVarInfo, logp)
322+
acclogp!!(vi::AbstractVarInfo, logp::NamedTuple)
323+
324+
Add to both the log prior and the log likelihood probabilities in `vi`.
307325
308-
Add `logp` to the value of the log of the joint probability of the observed data and
309-
parameters sampled in `vi`, mutating if it makes sense.
326+
`logp` should have fields `logprior` and/or `loglikelihood`, and no other fields.
310327
"""
311-
function acclogp!!(vi::AbstractVarInfo, logp)
312-
Base.depwarn(
313-
"acclogp!! is deprecated, use acclogprior!! or accloglikelihood!! instead",
314-
:acclogp!!,
328+
function acclogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names}
329+
if !(
330+
names == (:logprior, :loglikelihood) ||
331+
names == (:loglikelihood, :logprior) ||
332+
names == (:logprior,) ||
333+
names == (:loglikelihood,)
334+
)
335+
error("logp must have fields logprior and/or loglikelihood and no other fields.")
336+
end
337+
if haskey(logp, :logprior)
338+
vi = acclogprior!!(vi, logp.logprior)
339+
end
340+
if haskey(logp, :loglikelihood)
341+
vi = accloglikelihood!!(vi, logp.loglikelihood)
342+
end
343+
return vi
344+
end
345+
346+
function acclogp!!(vi::AbstractVarInfo, logp::Number)
347+
depwarn(
348+
"`acclogp!!(vi::AbstractVarInfo, logp::Number)` is deprecated. Use `accloglikelihood!!(vi, logp)` instead.",
349+
:acclogp,
315350
)
316351
return accloglikelihood!!(vi, logp)
317352
end

test/submodels.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ using Test
3535
@test model()[1] == x_val
3636
# Test that the logp was correctly set
3737
vi = VarInfo(model)
38-
@test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)])
38+
@test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)])
3939
# Check the keys
4040
@test Set(keys(VarInfo(model))) == Set([@varname(a.y)])
4141
end
@@ -67,7 +67,7 @@ using Test
6767
@test model()[1] == x_val
6868
# Test that the logp was correctly set
6969
vi = VarInfo(model)
70-
@test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(y)])
70+
@test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(y)])
7171
# Check the keys
7272
@test Set(keys(VarInfo(model))) == Set([@varname(y)])
7373
end
@@ -99,7 +99,7 @@ using Test
9999
@test model()[1] == x_val
100100
# Test that the logp was correctly set
101101
vi = VarInfo(model)
102-
@test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)])
102+
@test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)])
103103
# Check the keys
104104
@test Set(keys(VarInfo(model))) == Set([@varname(b.y)])
105105
end
@@ -148,7 +148,7 @@ using Test
148148
# No conditioning
149149
vi = VarInfo(h())
150150
@test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)])
151-
@test getlogp(vi) ==
151+
@test getlogjoint(vi) ==
152152
logpdf(Normal(), vi[@varname(a.b.x)]) +
153153
logpdf(Normal(), vi[@varname(a.b.y)])
154154

@@ -174,7 +174,7 @@ using Test
174174
@testset "$name" for (name, model) in models
175175
vi = VarInfo(model)
176176
@test Set(keys(vi)) == Set([@varname(a.b.y)])
177-
@test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)])
177+
@test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)])
178178
end
179179
end
180180
end

test/varinfo.jl

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,29 +175,68 @@ end
175175
vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi)))
176176
@test getlogprior(vi) == lp_a + lp_b
177177
@test getloglikelihood(vi) == lp_c + lp_d
178+
@test getlogp(vi) == (; logprior=lp_a + lp_b, loglikelihood=lp_c + lp_d)
179+
@test getlogjoint(vi) == lp_a + lp_b + lp_c + lp_d
178180
@test get_num_produce(vi) == 2
181+
@test begin
182+
vi = acclogprior!!(vi, 1.0)
183+
getlogprior(vi) == lp_a + lp_b + 1.0
184+
end
185+
@test begin
186+
vi = accloglikelihood!!(vi, 1.0)
187+
getloglikelihood(vi) == lp_c + lp_d + 1.0
188+
end
189+
@test begin
190+
vi = setlogprior!!(vi, -1.0)
191+
getlogprior(vi) == -1.0
192+
end
193+
@test begin
194+
vi = setloglikelihood!!(vi, -1.0)
195+
getloglikelihood(vi) == -1.0
196+
end
197+
@test begin
198+
vi = setlogp!!(vi, (logprior=-3.0, loglikelihood=-3.0))
199+
getlogp(vi) == (; logprior=-3.0, loglikelihood=-3.0)
200+
end
201+
@test begin
202+
vi = acclogp!!(vi, (logprior=1.0, loglikelihood=1.0))
203+
getlogp(vi) == (; logprior=-2.0, loglikelihood=-2.0)
204+
end
205+
@test getlogp(setlogp!!(vi, getlogp(vi))) == getlogp(vi)
179206

180207
vi = last(
181208
DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), (LogPrior(),)))
182209
)
183210
@test getlogprior(vi) == lp_a + lp_b
184211
@test_throws "has no field LogLikelihood" getloglikelihood(vi)
212+
@test_throws "has no field LogLikelihood" getlogp(vi)
185213
@test_throws "has no field LogLikelihood" getlogjoint(vi)
186214
@test_throws "has no field NumProduce" get_num_produce(vi)
215+
@test begin
216+
vi = acclogprior!!(vi, 1.0)
217+
getlogprior(vi) == lp_a + lp_b + 1.0
218+
end
219+
@test begin
220+
vi = setlogprior!!(vi, -1.0)
221+
getlogprior(vi) == -1.0
222+
end
223+
@test_throws "has no field LogLikelihood" setlogp!!(getlogp(vi))
187224

188225
vi = last(
189226
DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), (NumProduce(),)))
190227
)
191228
@test_throws "has no field LogPrior" getlogprior(vi)
192-
@test_throws "has no field LogPrior" getlogjoint(vi)
193229
@test_throws "has no field LogLikelihood" getloglikelihood(vi)
230+
@test_throws "has no field LogPrior" getlogp(vi)
231+
@test_throws "has no field LogPrior" getlogjoint(vi)
194232
@test get_num_produce(vi) == 2
195233

196234
# Test evaluating without any accumulators.
197235
vi = last(DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), ())))
198236
@test_throws "has no field LogPrior" getlogprior(vi)
199-
@test_throws "has no field LogPrior" getlogjoint(vi)
200237
@test_throws "has no field LogLikelihood" getloglikelihood(vi)
238+
@test_throws "has no field LogPrior" getlogp(vi)
239+
@test_throws "has no field LogPrior" getlogjoint(vi)
201240
@test_throws "has no field NumProduce" get_num_produce(vi)
202241
@test_throws "has no field NumProduce" reset_num_produce!!(vi)
203242
end

0 commit comments

Comments
 (0)