Skip to content

Commit 5376534

Browse files
authored
Replace elementwise_loglikelihoods with pointwise_loglikelihoods (#179)
1 parent 37ec450 commit 5376534

File tree

8 files changed

+50
-35
lines changed

8 files changed

+50
-35
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.9.4"
3+
version = "0.9.5"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/DynamicPPL.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ export AbstractVarInfo,
9191
# Convenience functions
9292
logprior,
9393
logjoint,
94-
elementwise_loglikelihoods,
94+
pointwise_loglikelihoods,
9595
# Convenience macros
9696
@addlogprob!
9797

@@ -122,4 +122,6 @@ include("prob_macro.jl")
122122
include("compat/ad.jl")
123123
include("loglikelihoods.jl")
124124

125+
include("deprecations.jl")
126+
125127
end # module

src/deprecations.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
@deprecate getmissing(model) getmissings(model)
2+
3+
# `@deprecate` doesn't work with qualified function names,
4+
# so we use the following hack
5+
const _base_in = Base.in
6+
@deprecate _base_in(vn::VarName, space::Tuple) inspace(vn, space)
7+
8+
@deprecate elementwise_loglikelihoods(
9+
model::Model, chain,
10+
) pointwise_loglikelihoods(
11+
model, chain, String,
12+
)
13+
@deprecate elementwise_loglikelihoods(
14+
model::Model, chain, ::Type{T},
15+
) where {T} pointwise_loglikelihoods(
16+
model, chain, T,
17+
)
18+
@deprecate elementwise_loglikelihoods(
19+
model::Model, varinfo::AbstractVarInfo,
20+
) pointwise_loglikelihoods(
21+
model, varinfo,
22+
)

src/loglikelihoods.jl

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
# Context version
2-
struct ElementwiseLikelihoodContext{A, Ctx} <: AbstractContext
2+
struct PointwiseLikelihoodContext{A, Ctx} <: AbstractContext
33
loglikelihoods::A
44
ctx::Ctx
55
end
66

7-
function ElementwiseLikelihoodContext(
7+
function PointwiseLikelihoodContext(
88
likelihoods = Dict{VarName, Vector{Float64}}(),
99
ctx::AbstractContext = LikelihoodContext()
1010
)
11-
return ElementwiseLikelihoodContext{typeof(likelihoods),typeof(ctx)}(likelihoods, ctx)
11+
return PointwiseLikelihoodContext{typeof(likelihoods),typeof(ctx)}(likelihoods, ctx)
1212
end
1313

1414
function Base.push!(
15-
ctx::ElementwiseLikelihoodContext{Dict{VarName, Vector{Float64}}},
15+
ctx::PointwiseLikelihoodContext{Dict{VarName, Vector{Float64}}},
1616
vn::VarName,
1717
logp::Real
1818
)
@@ -22,15 +22,15 @@ function Base.push!(
2222
end
2323

2424
function Base.push!(
25-
ctx::ElementwiseLikelihoodContext{Dict{VarName, Float64}},
25+
ctx::PointwiseLikelihoodContext{Dict{VarName, Float64}},
2626
vn::VarName,
2727
logp::Real
2828
)
2929
ctx.loglikelihoods[vn] = logp
3030
end
3131

3232
function Base.push!(
33-
ctx::ElementwiseLikelihoodContext{Dict{String, Vector{Float64}}},
33+
ctx::PointwiseLikelihoodContext{Dict{String, Vector{Float64}}},
3434
vn::VarName,
3535
logp::Real
3636
)
@@ -40,15 +40,15 @@ function Base.push!(
4040
end
4141

4242
function Base.push!(
43-
ctx::ElementwiseLikelihoodContext{Dict{String, Float64}},
43+
ctx::PointwiseLikelihoodContext{Dict{String, Float64}},
4444
vn::VarName,
4545
logp::Real
4646
)
4747
ctx.loglikelihoods[string(vn)] = logp
4848
end
4949

5050
function Base.push!(
51-
ctx::ElementwiseLikelihoodContext{Dict{String, Vector{Float64}}},
51+
ctx::PointwiseLikelihoodContext{Dict{String, Vector{Float64}}},
5252
vn::String,
5353
logp::Real
5454
)
@@ -58,26 +58,26 @@ function Base.push!(
5858
end
5959

6060
function Base.push!(
61-
ctx::ElementwiseLikelihoodContext{Dict{String, Float64}},
61+
ctx::PointwiseLikelihoodContext{Dict{String, Float64}},
6262
vn::String,
6363
logp::Real
6464
)
6565
ctx.loglikelihoods[vn] = logp
6666
end
6767

6868

69-
function tilde_assume(rng, ctx::ElementwiseLikelihoodContext, sampler, right, vn, inds, vi)
69+
function tilde_assume(rng, ctx::PointwiseLikelihoodContext, sampler, right, vn, inds, vi)
7070
return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi)
7171
end
7272

73-
function dot_tilde_assume(rng, ctx::ElementwiseLikelihoodContext, sampler, right, left, vn, inds, vi)
73+
function dot_tilde_assume(rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi)
7474
value, logp = dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi)
7575
acclogp!(vi, logp)
7676
return value
7777
end
7878

7979

80-
function tilde_observe(ctx::ElementwiseLikelihoodContext, sampler, right, left, vname, vinds, vi)
80+
function tilde_observe(ctx::PointwiseLikelihoodContext, sampler, right, left, vname, vinds, vi)
8181
# This is slightly unfortunate since it is not completely generic...
8282
# Ideally we would call `tilde_observe` recursively but then we don't get the
8383
# loglikelihood value.
@@ -92,7 +92,7 @@ end
9292

9393

9494
"""
95-
elementwise_loglikelihoods(model::Model, chain::Chains, keytype = String)
95+
pointwise_loglikelihoods(model::Model, chain::Chains, keytype = String)
9696
9797
Runs `model` on each sample in `chain` returning a `Dict{String, Matrix{Float64}}`
9898
with keys corresponding to symbols of the observations, and values being matrices
@@ -138,37 +138,37 @@ julia> model = demo(randn(3), randn());
138138
139139
julia> chain = sample(model, MH(), 10);
140140
141-
julia> elementwise_loglikelihoods(model, chain)
141+
julia> pointwise_loglikelihoods(model, chain)
142142
Dict{String,Array{Float64,2}} with 4 entries:
143143
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
144144
"xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
145145
"xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
146146
"y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
147147
148-
julia> elementwise_loglikelihoods(model, chain, String)
148+
julia> pointwise_loglikelihoods(model, chain, String)
149149
Dict{String,Array{Float64,2}} with 4 entries:
150150
"xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
151151
"xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
152152
"xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
153153
"y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
154154
155-
julia> elementwise_loglikelihoods(model, chain, VarName)
155+
julia> pointwise_loglikelihoods(model, chain, VarName)
156156
Dict{VarName,Array{Float64,2}} with 4 entries:
157157
xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359]
158158
y => [-1.51265; -0.914129; … ; -1.5499; -1.5499]
159159
xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333]
160160
xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251]
161161
```
162162
"""
163-
function elementwise_loglikelihoods(
163+
function pointwise_loglikelihoods(
164164
model::Model,
165165
chain,
166166
keytype::Type{T} = String
167167
) where {T}
168168
# Get the data by executing the model once
169169
spl = SampleFromPrior()
170170
vi = VarInfo(model)
171-
ctx = ElementwiseLikelihoodContext(Dict{T, Vector{Float64}}())
171+
ctx = PointwiseLikelihoodContext(Dict{T, Vector{Float64}}())
172172

173173
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
174174
for (sample_idx, chain_idx) in iters
@@ -188,8 +188,8 @@ function elementwise_loglikelihoods(
188188
return loglikelihoods
189189
end
190190

191-
function elementwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo)
192-
ctx = ElementwiseLikelihoodContext(Dict{VarName, Float64}())
191+
function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo)
192+
ctx = PointwiseLikelihoodContext(Dict{VarName, Float64}())
193193
model(varinfo, SampleFromPrior(), ctx)
194194
return ctx.loglikelihoods
195195
end

src/model.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,6 @@ Get a tuple of the names of the missing arguments of the `model`.
177177
"""
178178
getmissings(model::Model{_F,_a,_d,missings}) where {missings,_F,_a,_d} = missings
179179

180-
getmissing(model::Model) = getmissings(model)
181-
@deprecate getmissing(model) getmissings(model)
182-
183180
"""
184181
nameof(model::Model)
185182

src/varname.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,6 @@ inspace(vn, space::Tuple) = vn in space
8282
inspace(vn::VarName, space::Tuple{}) = true
8383
inspace(vn::VarName, space::Tuple) = any(_in(vn, s) for s in space)
8484

85-
@noinline function Base.in(vn::VarName, space::Tuple)
86-
Base.depwarn("`Base.in(vn::VarName, space::Tuple)` is deprecated, use `inspace(vn, space)` instead.",
87-
nameof(Base.in))
88-
return inspace(vn, space)
89-
end
90-
9185
_in(vn::VarName, s::Symbol) = getsym(vn) == s
9286
_in(vn::VarName, s::VarName) = subsumes(s, vn)
9387

test/Turing/Turing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ end
6767
# Exports #
6868
###########
6969
# `using` statements for stuff to re-export
70-
using DynamicPPL: elementwise_loglikelihoods, generated_quantities, logprior, logjoint
70+
using DynamicPPL: pointwise_loglikelihoods, generated_quantities, logprior, logjoint
7171
using StatsBase: predict
7272

7373
# Turing essentials - modelling macros and inference algorithms
@@ -122,7 +122,7 @@ export @model, # modelling
122122
arraydist,
123123

124124
predict,
125-
elementwise_loglikelihoods,
125+
pointwise_loglikelihoods,
126126
genereated_quantities,
127127
logprior,
128128
logjoint

test/loglikelihoods.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using .Turing
1515
y = randn();
1616
model = demo(xs, y);
1717
chain = sample(model, MH(), MCMCThreads(), 100, 2);
18-
var_to_likelihoods = elementwise_loglikelihoods(model, chain)
18+
var_to_likelihoods = pointwise_loglikelihoods(model, chain)
1919
@test haskey(var_to_likelihoods, "xs[1]")
2020
@test haskey(var_to_likelihoods, "xs[2]")
2121
@test haskey(var_to_likelihoods, "xs[3]")
@@ -31,7 +31,7 @@ using .Turing
3131
end
3232

3333
var_info = VarInfo(model)
34-
results = DynamicPPL.elementwise_loglikelihoods(model, var_info)
34+
results = pointwise_loglikelihoods(model, var_info)
3535
var_to_likelihoods = Dict(string(vn) =>for (vn, ℓ) in results)
3636
s, m = var_info[SampleFromPrior()]
3737
@test logpdf(Normal(m, s), xs[1]) == var_to_likelihoods["xs[1]"]

0 commit comments

Comments
 (0)