Description
Lines 1 to 19 in 0810e14

```julia
# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler`
# That would let us use all defaults for Sampler, combine it with other samplers etc.
"""
    SampleFromUniform

Sampling algorithm that samples unobserved random variables from a uniform distribution.

# References

[Stan reference manual](https://mc-stan.org/docs/2_28/reference-manual/initialization.html#random-initial-values)
"""
struct SampleFromUniform <: AbstractSampler end

"""
    SampleFromPrior

Sampling algorithm that samples unobserved random variables from their prior distribution.
"""
struct SampleFromPrior <: AbstractSampler end
```
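The docstring only says "a uniform distribution"; the linked Stan manual describes drawing each unconstrained parameter uniformly from (-2, 2) and then mapping it back to the constrained support. A minimal sketch of that scheme for two scalar constraints, assuming hand-written inverse links. `invlink_positive`, `invlink_unit`, `uniform_unconstrained`, and `stan_style_init` are hypothetical names, not DynamicPPL API, and this is not a claim about how DynamicPPL's `init` is implemented.

```julia
using Random

# Hypothetical helpers for illustration; these are NOT DynamicPPL functions.
# Stan-style random initialization: draw the unconstrained value uniformly
# from (-2, 2), then map it to the constrained support with the inverse link.

# Inverse link for a variable constrained to (0, Inf).
invlink_positive(y) = exp(y)

# Inverse link for a variable constrained to (0, 1) (logistic function).
invlink_unit(y) = 1 / (1 + exp(-y))

# One draw from Uniform(-2, 2) on the unconstrained scale (rand is in [0, 1)).
uniform_unconstrained(rng::AbstractRNG) = 4 * rand(rng) - 2

stan_style_init(rng::AbstractRNG, invlink) = invlink(uniform_unconstrained(rng))

rng = MersenneTwister(1)
sigma = stan_style_init(rng, invlink_positive)  # lies in [exp(-2), exp(2)]
p = stan_style_init(rng, invlink_unit)          # lies in [invlink_unit(-2), invlink_unit(2)]
```

Unlike a draw from the prior, the result stays in a fixed bounded range no matter how heavy the prior's tails are.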
DynamicPPL.jl/src/context_implementations.jl
Lines 212 to 260 in 0810e14

```julia
# TODO: Remove this thing.
# SampleFromPrior and SampleFromUniform
function assume(
    rng::Random.AbstractRNG,
    sampler::Union{SampleFromPrior,SampleFromUniform},
    dist::Distribution,
    vn::VarName,
    vi::VarInfoOrThreadSafeVarInfo,
)
    if haskey(vi, vn)
        # Always overwrite the parameters with new ones for `SampleFromUniform`.
        if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
            # TODO(mhauru) Is it important to unset the flag here? The `true` allows us
            # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
            # if that's okay.
            unset_flag!(vi, vn, "del", true)
            r = init(rng, dist, sampler)
            f = to_maybe_linked_internal_transform(vi, vn, dist)
            # TODO(mhauru) This should probably be call a function called setindex_internal!
            # Also, if we use !! we shouldn't ignore the return value.
            BangBang.setindex!!(vi, f(r), vn)
            setorder!(vi, vn, get_num_produce(vi))
        else
            # Otherwise we just extract it.
            r = vi[vn, dist]
        end
    else
        r = init(rng, dist, sampler)
        if istrans(vi)
            f = to_linked_internal_transform(vi, vn, dist)
            push!!(vi, vn, f(r), dist)
            # By default `push!!` sets the transformed flag to `false`.
            settrans!!(vi, true, vn)
        else
            push!!(vi, vn, r, dist)
        end
    end
    # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
    logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
    return r, logpdf(dist, r) - logjac, vi
end

# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`)
observe(sampler::AbstractSampler, right, left, vi) = observe(right, left, vi)

function observe(right::Distribution, left, vi)
    increment_num_produce!(vi)
    return Distributions.loglikelihood(right, left), vi
end
```
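One way to read the `HACK` line at the end of `assume` is as a change of variables (an interpretation, not stated in the code). Writing $x$ for `r`, $f$ for `link_transform(dist)`, and $y = f(x)$ for the linked value, the log-density of $y$ is

```math
\log p_Y(y) = \log p_X\!\left(f^{-1}(y)\right) + \log\left|\det J_{f^{-1}}(y)\right| = \log p_X(x) - \log\left|\det J_{f}(x)\right|
```

The subtracted term corresponds to `logabsdetjac(link_transform(dist), r)`. When the variable is not transformed, $f$ is `identity`, the Jacobian term is zero, and the returned value reduces to `logpdf(dist, r)`.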
This is all kinda hacky and also confusing, since `SampleFromPrior` doesn't actually sample if `haskey(vi, vn)`, and it's hard to see why: in that case the inner `if` only fires when the `"del"` flag is set, so otherwise we just get `r = vi[vn, dist]`.
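To make the hidden branch visible, here is a toy model of the control flow in `assume` above, assuming a `Dict` in place of `vi` and a `Set` of names in place of the `"del"` flag. `ToyPrior`, `ToyUniform`, `toy_init`, and `toy_assume` are hypothetical names; transforms, `setorder!`, and the log-density bookkeeping are deliberately left out.

```julia
using Random

# Hypothetical toy model of the branching in `assume` above; NOT DynamicPPL code.
# A Dict stands in for `vi`, and a Set of names stands in for the "del" flag.
abstract type ToySampler end
struct ToyPrior <: ToySampler end
struct ToyUniform <: ToySampler end

# Stand-ins for `init(rng, dist, sampler)`: a standard-normal "prior" draw
# versus a Uniform(-2, 2) draw.
toy_init(rng::AbstractRNG, ::ToyPrior) = randn(rng)
toy_init(rng::AbstractRNG, ::ToyUniform) = 4 * rand(rng) - 2

function toy_assume(rng::AbstractRNG, sampler::ToySampler, vn::Symbol,
                    vi::Dict{Symbol,Float64}, del::Set{Symbol})
    if haskey(vi, vn)
        if sampler isa ToyUniform || vn in del
            delete!(del, vn)            # mirrors `unset_flag!(vi, vn, "del", true)`
            r = toy_init(rng, sampler)
            vi[vn] = r                  # overwrite with a fresh draw
        else
            r = vi[vn]                  # ToyPrior, no "del" flag: reuse stored value
        end
    else
        r = toy_init(rng, sampler)      # first visit: draw and store
        vi[vn] = r
    end
    return r
end

rng = MersenneTwister(42)
vi = Dict{Symbol,Float64}()
del = Set{Symbol}()

a = toy_assume(rng, ToyPrior(), :x, vi, del)    # draws and stores
b = toy_assume(rng, ToyPrior(), :x, vi, del)    # no new draw: b == a
push!(del, :x)
c = toy_assume(rng, ToyPrior(), :x, vi, del)    # "del" set: fresh draw, flag cleared
d = toy_assume(rng, ToyUniform(), :x, vi, del)  # ToyUniform: always a fresh draw
```

The second call shows the surprising part: with the key already present and no `"del"` flag, the "prior sampler" just hands back whatever is stored.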
It should be cleaned up and the behaviour made more consistent. Also, the TODO comment about using `Sampler{Prior}` and `Sampler{Uniform}` IMO makes a lot of sense. Note that this could have further impacts on Turing.jl, because there is some `SampleFromPrior` / `SampleFromUniform` type piracy there. It might also affect the implementation of TuringLang/Turing.jl#2476.
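A minimal sketch of the direction the TODO suggests, assuming the two behaviours become algorithm types dispatched on through a parametric wrapper and that reusing a stored value becomes an explicit choice at the call site. Every name below (`InitAlgorithm`, `Prior`, `Uniform`, `Sampler`, `draw`, `assume_sketch`) is hypothetical, DynamicPPL's real `Sampler` is not reproduced, and the `resample` keyword is an assumption of this sketch.

```julia
using Random: AbstractRNG, MersenneTwister

# Hypothetical sketch of "make them algs + just use `Sampler`" from the TODO.
# All names are illustrative only; this is not DynamicPPL's API.
abstract type InitAlgorithm end
struct Prior <: InitAlgorithm end
struct Uniform <: InitAlgorithm end

struct Sampler{A<:InitAlgorithm}
    alg::A
end

# One method per algorithm says *how* a fresh value is produced.
draw(rng::AbstractRNG, ::Sampler{Prior}) = randn(rng)          # stand-in for rand(rng, dist)
draw(rng::AbstractRNG, ::Sampler{Uniform}) = 4 * rand(rng) - 2  # Uniform(-2, 2)

# *Whether* to draw is an explicit keyword, not a side effect of `haskey`
# combined with the sampler type.
function assume_sketch(rng::AbstractRNG, spl::Sampler, vn::Symbol,
                       vi::Dict{Symbol,Float64}; resample::Bool)
    if resample || !haskey(vi, vn)
        vi[vn] = draw(rng, spl)
    end
    return vi[vn]
end

rng = MersenneTwister(0)
vi = Dict{Symbol,Float64}()
x1 = assume_sketch(rng, Sampler(Prior()), :x, vi; resample=false)   # first visit: draws
x2 = assume_sketch(rng, Sampler(Prior()), :x, vi; resample=false)   # reuses: x2 == x1
x3 = assume_sketch(rng, Sampler(Uniform()), :x, vi; resample=true)  # fresh draw in [-2, 2)
```

The design choice illustrated: dispatch on `Sampler{Prior}` vs `Sampler{Uniform}` only decides how to draw, while whether to draw is visible at the call site instead of being buried in a nested `if`.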