Skip to content

Commit 7aa9ebe

Browse files
committed
undeprecate pointwise_loglikelihoods and implement pointwise_prior_logdensities
mostly taken from TuringLang#669
1 parent 656a757 commit 7aa9ebe

File tree

6 files changed

+64
-43
lines changed

6 files changed

+64
-43
lines changed

src/DynamicPPL.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,9 @@ export AbstractVarInfo,
115115
# Convenience functions
116116
logprior,
117117
logjoint,
118-
pointwise_loglikelihoods,
118+
pointwise_prior_logdensities,
119119
pointwise_logdensities,
120+
pointwise_loglikelihoods,
120121
condition,
121122
decondition,
122123
fix,
@@ -190,7 +191,6 @@ include("logdensityfunction.jl")
190191
include("model_utils.jl")
191192
include("extract_priors.jl")
192193
include("values_as_in_model.jl")
193-
include("deprecated.jl")
194194

195195
include("debug_utils.jl")
196196
using .DebugUtils

src/deprecated.jl

Lines changed: 0 additions & 9 deletions
This file was deleted.

src/pointwise_logdensities.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,62 @@ function pointwise_logdensities(model::Model,
322322
end
323323

324324

325+
326+
327+
"""
328+
pointwise_loglikelihoods(model, chain[, keytype, context])
329+
Compute the pointwise log-likelihoods of the model given the chain.
330+
This is the same as `pointwise_logdensities(model, chain, context)`, but only
331+
including the likelihood terms.
332+
See also: [`pointwise_logdensities`](@ref).
333+
"""
334+
function pointwise_loglikelihoods(
335+
model::Model,
336+
chain,
337+
keytype::Type{T}=String,
338+
context::AbstractContext=LikelihoodContext(),
339+
) where {T}
340+
if !(leafcontext(context) isa LikelihoodContext)
341+
throw(ArgumentError("Leaf context should be a LikelihoodContext"))
342+
end
343+
344+
return pointwise_logdensities(model, chain, T, context)
345+
end
346+
347+
function pointwise_loglikelihoods(
348+
model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext()
349+
)
350+
if !(leafcontext(context) isa LikelihoodContext)
351+
throw(ArgumentError("Leaf context should be a LikelihoodContext"))
352+
end
353+
354+
return pointwise_logdensities(model, varinfo, context)
355+
end
356+
357+
"""
358+
pointwise_prior_logdensities(model, chain[, keytype, context])
359+
Compute the pointwise log-prior-densities of the model given the chain.
360+
This is the same as `pointwise_logdensities(model, chain, context)`, but only
361+
including the prior terms.
362+
See also: [`pointwise_logdensities`](@ref).
363+
"""
364+
function pointwise_prior_logdensities(
365+
model::Model, chain, keytype::Type{T}=String, context::AbstractContext=PriorContext()
366+
) where {T}
367+
if !(leafcontext(context) isa PriorContext)
368+
throw(ArgumentError("Leaf context should be a PriorContext"))
369+
end
370+
371+
return pointwise_logdensities(model, chain, T, context)
372+
end
373+
374+
function pointwise_prior_logdensities(
375+
model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext()
376+
)
377+
if !(leafcontext(context) isa PriorContext)
378+
throw(ArgumentError("Leaf context should be a PriorContext"))
379+
end
380+
381+
return pointwise_logdensities(model, varinfo, context)
382+
end
383+

test/deprecated.jl

Lines changed: 0 additions & 28 deletions
This file was deleted.

test/pointwise_logdensities.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
logp_true = logprior(m, vi)
2828

2929
# Compute the pointwise loglikelihoods.
30-
lls = pointwise_logdensities(m, vi, likelihood_context)
30+
lls = pointwise_loglikelihoods(m, vi)
3131
#lls2 = pointwise_loglikelihoods(m, vi)
3232
if isempty(lls)
3333
# One of the models with literal observations, so we just skip.
@@ -38,7 +38,7 @@
3838
end
3939

4040
# Compute the pointwise logdensities of the priors.
41-
lps_prior = pointwise_logdensities(m, vi, prior_context)
41+
lps_prior = pointwise_prior_logdensities(m, vi)
4242
logp = sum(sum, values(lps_prior))
4343
logp1 = getlogp(vi)
4444
@test !isfinite(logp_true) || logp logp_true
@@ -56,6 +56,7 @@
5656
end
5757
end
5858

59+
5960
@testset "pointwise_logdensities chain" begin
6061
@model function demo(x, ::Type{TV}=Vector{Float64}) where {TV}
6162
s ~ InverseGamma(2, 3)

test/runtests.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@ include("test_util.jl")
6060
include("pointwise_logdensities.jl")
6161

6262
include("lkj.jl")
63-
64-
include("deprecated.jl")
6563
end
6664

6765
@testset "compat" begin

0 commit comments

Comments
 (0)