Skip to content

Commit 9043d65

Browse files
allow redefinition of inputs in logprob (#192)
Co-authored-by: David Widmann <[email protected]>
1 parent d480908 commit 9043d65

File tree

4 files changed

+23
-4
lines changed

4 files changed

+23
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.9.7"
3+
version = "0.9.8"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/context_implementations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function tilde(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi)
3030
return _tilde(rng, sampler, right, vn, vi)
3131
end
3232
function tilde(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi)
33-
if ctx.vars !== nothing
33+
if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn))
3434
vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds))
3535
settrans!(vi, false, vn)
3636
end
@@ -169,7 +169,7 @@ function dot_tilde(
169169
inds,
170170
vi,
171171
)
172-
if ctx.vars !== nothing
172+
if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn))
173173
var = _getindex(getfield(ctx.vars, getsym(vn)), inds)
174174
vns, dist = get_vns_and_dist(right, var, vn)
175175
set_val!(vi, vns, dist, var)

src/prob_macro.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ function Distributions.loglikelihood(
190190
if isdefined(right, :chain)
191191
# Element-wise likelihood for each value in chain
192192
chain = right.chain
193-
ctx = LikelihoodContext()
193+
ctx = LikelihoodContext(right)
194194
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
195195
logps = map(iters) do (sample_idx, chain_idx)
196196
setval!(vi, chain, sample_idx, chain_idx)

test/prob_macro.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,23 @@ Random.seed!(129)
128128
chain2 = sample(model2(y, group, n_groups), NUTS(0.65), 2_000; save_state=true)
129129
logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain2"
130130
end
131+
132+
@testset "issue190" begin
133+
@model function gdemo(x, y)
134+
s ~ InverseGamma(2, 3)
135+
m ~ Normal(0, sqrt(s))
136+
x ~ filldist(Normal(m, sqrt(s)), length(y))
137+
for i in 1:length(y)
138+
y[i] ~ Normal(x[i], sqrt(s))
139+
end
140+
end
141+
c = Chains(rand(10, 2), [:m, :s])
142+
model_gdemo = gdemo([1.0, 0.0], [1.5, 0.0])
143+
r1 = prob"y = [1.5] | chain=c, model = model_gdemo, x = [1.0]"
144+
r2 = map(c[:s]) do s
145+
# exp(logpdf(..)) not pdf because this is exactly what the prob"" macro does, so we test r1 == r2
146+
exp(logpdf(Normal(1, sqrt(s)), 1.5))
147+
end
148+
@test r1 == r2
149+
end
131150
end

0 commit comments

Comments
 (0)