Skip to content

Commit b0bb31e

Browse files
authored
Update the currently buggy and incorrect tilde overloads in mh.jl (#2360)
1 parent 452d0d0 commit b0bb31e

File tree

1 file changed

+23
-20
lines changed

1 file changed

+23
-20
lines changed

src/mcmc/mh.jl

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -442,42 +442,45 @@ end
442442
####
443443
#### Compiler interface, i.e. tilde operators.
444444
####
445-
function DynamicPPL.assume(rng, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi)
445+
function DynamicPPL.assume(
446+
rng::Random.AbstractRNG, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi
447+
)
448+
# Just defer to `SampleFromPrior`.
449+
retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
450+
# Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call.
446451
DynamicPPL.updategid!(vi, vn, spl)
447-
r = vi[vn]
448-
return r, logpdf_with_trans(dist, r, istrans(vi, vn)), vi
452+
# Return.
453+
return retval
449454
end
450455

451456
function DynamicPPL.dot_assume(
452457
rng,
453458
spl::Sampler{<:MH},
454459
dist::MultivariateDistribution,
455-
vn::VarName,
460+
vns::AbstractVector{<:VarName},
456461
var::AbstractMatrix,
457-
vi,
462+
vi::AbstractVarInfo,
458463
)
459-
@assert dim(dist) == size(var, 1)
460-
getvn = i -> VarName(vn, vn.indexing * "[:,$i]")
461-
vns = getvn.(1:size(var, 2))
462-
DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl))
463-
r = vi[vns]
464-
var .= r
465-
return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1]))), vi
464+
# Just defer to `SampleFromPrior`.
465+
retval = DynamicPPL.dot_assume(rng, SampleFromPrior(), dist, vns[1], var, vi)
466+
# Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call.
467+
DynamicPPL.updategid!.((vi,), vns, (spl,))
468+
# Return.
469+
return retval
466470
end
467471
function DynamicPPL.dot_assume(
468472
rng,
469473
spl::Sampler{<:MH},
470474
dists::Union{Distribution,AbstractArray{<:Distribution}},
471-
vn::VarName,
475+
vns::AbstractArray{<:VarName},
472476
var::AbstractArray,
473-
vi,
477+
vi::AbstractVarInfo,
474478
)
475-
getvn = ind -> VarName(vn, vn.indexing * "[" * join(Tuple(ind), ",") * "]")
476-
vns = getvn.(CartesianIndices(var))
477-
DynamicPPL.updategid!.(Ref(vi), vns, Ref(spl))
478-
r = reshape(vi[vec(vns)], size(var))
479-
var .= r
480-
return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1]))), vi
479+
# Just defer to `SampleFromPrior`.
480+
retval = DynamicPPL.dot_assume(rng, SampleFromPrior(), dists, vns, var, vi)
481+
# Update the Gibbs IDs because they might have been assigned in the `SampleFromPrior` call.
482+
DynamicPPL.updategid!.((vi,), vns, (spl,))
483+
return retval
481484
end
482485

483486
function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi)

0 commit comments

Comments
 (0)