-
Notifications
You must be signed in to change notification settings - Fork 36
InitContext
, part 5 - Remove SamplingContext
, SampleFrom{Prior,Uniform}
, {tilde_,}assume
#985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
5a5a4e9
dde8b7e
d7c4033
6c776e9
992569f
6974cc1
aaece0b
77a8710
f4e5f4b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry this file's diff is a bit of a mess, there are no real code changes, it's just:
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,93 +1,102 @@ | ||
# assume | ||
""" | ||
tilde_assume(context::SamplingContext, right, vn, vi) | ||
|
||
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), | ||
accumulate the log probability, and return the sampled value with a context associated | ||
with a sampler. | ||
|
||
Falls back to | ||
```julia | ||
tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) | ||
``` | ||
DynamicPPL.tilde_assume!!( | ||
context::AbstractContext, | ||
right::Distribution, | ||
vn::VarName, | ||
vi::AbstractVarInfo | ||
) | ||
|
||
Handle assumed variables, i.e. anything which is not observed (see | ||
[`tilde_observe!!`](@ref)). Accumulate the associated log probability, and return the | ||
sampled value and updated `vi`. | ||
|
||
`vn` is the VarName on the left-hand side of the tilde statement. | ||
""" | ||
function tilde_assume(context::SamplingContext, right, vn, vi) | ||
return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) | ||
end | ||
|
||
function tilde_assume(context::AbstractContext, args...) | ||
return tilde_assume(childcontext(context), args...) | ||
end | ||
function tilde_assume(::DefaultContext, right, vn, vi) | ||
return assume(right, vn, vi) | ||
end | ||
|
||
function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) | ||
return tilde_assume(rng, childcontext(context), args...) | ||
end | ||
function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) | ||
return assume(rng, sampler, right, vn, vi) | ||
function tilde_assume!!( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To confirm that I understand how this will play out in Turing.jl: The idea is that samplers don't need to modify the behaviour of the tilde pipeline any more, and thus There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think so, although I'm not 100% sure. We'll probably have to do a similar process to what we did last time for accumulators: make a Turing PR that builds against the breaking branch of DPPL. Things might break. But I'm hoping it won't be too bad haha. |
||
context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo | ||
) | ||
return tilde_assume!!(childcontext(context), right, vn, vi) | ||
end | ||
function tilde_assume(::DefaultContext, sampler, right, vn, vi) | ||
# same as above but no rng | ||
return assume(Random.default_rng(), sampler, right, vn, vi) | ||
function tilde_assume!!( | ||
::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo | ||
) | ||
y = getindex_internal(vi, vn) | ||
f = from_maybe_linked_internal_transform(vi, vn, right) | ||
x, inv_logjac = with_logabsdet_jacobian(f, y) | ||
vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) | ||
return x, vi | ||
end | ||
|
||
function tilde_assume(context::PrefixContext, right, vn, vi) | ||
function tilde_assume!!( | ||
context::PrefixContext, right::Distribution, vn::VarName, vi::AbstractVarInfo | ||
) | ||
# Note that we can't use something like this here: | ||
# new_vn = prefix(context, vn) | ||
# return tilde_assume(childcontext(context), right, new_vn, vi) | ||
# return tilde_assume!!(childcontext(context), right, new_vn, vi) | ||
# This is because `prefix` applies _all_ prefixes in a given context to a | ||
# variable name. Thus, if we had two levels of nested prefixes e.g. | ||
# `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the | ||
# first call would apply the prefix `a.b._`, and the recursive call | ||
# would apply the prefix `b._`, resulting in `b.a.b._`. | ||
# This is why we need a special function, `prefix_and_strip_contexts`. | ||
new_vn, new_context = prefix_and_strip_contexts(context, vn) | ||
return tilde_assume(new_context, right, new_vn, vi) | ||
return tilde_assume!!(new_context, right, new_vn, vi) | ||
end | ||
function tilde_assume( | ||
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi | ||
""" | ||
DynamicPPL.tilde_assume!!( | ||
context::AbstractContext, | ||
right::DynamicPPL.Submodel, | ||
vn::VarName, | ||
vi::AbstractVarInfo | ||
) | ||
|
||
Evaluate the submodel with the given context. | ||
""" | ||
function tilde_assume!!( | ||
context::AbstractContext, right::DynamicPPL.Submodel, vn::VarName, vi::AbstractVarInfo | ||
) | ||
new_vn, new_context = prefix_and_strip_contexts(context, vn) | ||
return tilde_assume(rng, new_context, sampler, right, new_vn, vi) | ||
return _evaluate!!(right, vi, context, vn) | ||
end | ||
|
||
""" | ||
tilde_assume!!(context, right, vn, vi) | ||
tilde_observe!!( | ||
context::AbstractContext, | ||
right::Distribution, | ||
left, | ||
vn::Union{VarName, Nothing}, | ||
vi::AbstractVarInfo | ||
) | ||
|
||
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), | ||
accumulate the log probability, and return the sampled value and updated `vi`. | ||
This function handles observed variables, which may be: | ||
|
||
By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log | ||
probability of `vi` with the returned value. | ||
""" | ||
function tilde_assume!!(context, right, vn, vi) | ||
return if right isa DynamicPPL.Submodel | ||
_evaluate!!(right, vi, context, vn) | ||
else | ||
tilde_assume(context, right, vn, vi) | ||
end | ||
end | ||
- literals on the left-hand side, e.g., `3.0 ~ Normal()` | ||
- a model input, e.g. `x ~ Normal()` in a model `@model f(x) ... end` | ||
- a conditioned or fixed variable, e.g. `x ~ Normal()` in a model `model | (; x = 3.0)`. | ||
|
||
# observe | ||
""" | ||
tilde_observe!!(context::SamplingContext, right, left, vi) | ||
The relevant log-probability associated with the observation is computed and accumulated in | ||
the VarInfo object `vi` (except for fixed variables, which do not contribute to the | ||
log-probability). | ||
|
||
Handle observed constants with a `context` associated with a sampler. | ||
`left` is the actual value that the left-hand side evaluates to. `vn` is the VarName on the | ||
left-hand side, or `nothing` if the left-hand side is a literal value. | ||
|
||
Falls back to `tilde_observe!!(context.context, right, left, vi)`. | ||
Observations of submodels are not yet supported in DynamicPPL. | ||
""" | ||
function tilde_observe!!(context::SamplingContext, right, left, vn, vi) | ||
return tilde_observe!!(context.context, right, left, vn, vi) | ||
end | ||
|
||
function tilde_observe!!(context::AbstractContext, right, left, vn, vi) | ||
function tilde_observe!!( | ||
context::AbstractContext, | ||
right::Distribution, | ||
left, | ||
vn::Union{VarName,Nothing}, | ||
vi::AbstractVarInfo, | ||
) | ||
return tilde_observe!!(childcontext(context), right, left, vn, vi) | ||
end | ||
|
||
# `PrefixContext` | ||
function tilde_observe!!(context::PrefixContext, right, left, vn, vi) | ||
function tilde_observe!!( | ||
context::PrefixContext, | ||
right::Distribution, | ||
left, | ||
vn::Union{VarName,Nothing}, | ||
vi::AbstractVarInfo, | ||
) | ||
# In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal | ||
# value. For the need for prefix_and_strip_contexts rather than just prefix, see the | ||
# comment in `tilde_assume!!`. | ||
|
@@ -98,74 +107,22 @@ function tilde_observe!!(context::PrefixContext, right, left, vn, vi) | |
end | ||
return tilde_observe!!(new_context, right, left, new_vn, vi) | ||
end | ||
|
||
""" | ||
tilde_observe!!(context, right, left, vn, vi) | ||
|
||
Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), | ||
accumulate the log probability, and return the observed value and updated `vi`. | ||
|
||
Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name | ||
and indices; if needed, these can be accessed through this function, though. | ||
""" | ||
function tilde_observe!!(::DefaultContext, right, left, vn, vi) | ||
right isa DynamicPPL.Submodel && | ||
throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) | ||
function tilde_observe!!( | ||
::DefaultContext, | ||
right::Distribution, | ||
left, | ||
vn::Union{VarName,Nothing}, | ||
vi::AbstractVarInfo, | ||
) | ||
vi = accumulate_observe!!(vi, right, left, vn) | ||
return left, vi | ||
end | ||
|
||
function assume(::Random.AbstractRNG, spl::Sampler, dist) | ||
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") | ||
end | ||
|
||
# fallback without sampler | ||
function assume(dist::Distribution, vn::VarName, vi) | ||
y = getindex_internal(vi, vn) | ||
f = from_maybe_linked_internal_transform(vi, vn, dist) | ||
x, inv_logjac = with_logabsdet_jacobian(f, y) | ||
vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) | ||
return x, vi | ||
end | ||
|
||
# TODO: Remove this thing. | ||
# SampleFromPrior and SampleFromUniform | ||
function assume( | ||
rng::Random.AbstractRNG, | ||
sampler::Union{SampleFromPrior,SampleFromUniform}, | ||
dist::Distribution, | ||
vn::VarName, | ||
vi::VarInfoOrThreadSafeVarInfo, | ||
function tilde_observe!!( | ||
::AbstractContext, | ||
::DynamicPPL.Submodel, | ||
left, | ||
vn::Union{VarName,Nothing}, | ||
::AbstractVarInfo, | ||
) | ||
if haskey(vi, vn) | ||
# Always overwrite the parameters with new ones for `SampleFromUniform`. | ||
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") | ||
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us | ||
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure | ||
# if that's okay. | ||
unset_flag!(vi, vn, "del", true) | ||
r = init(rng, dist, sampler) | ||
f = to_maybe_linked_internal_transform(vi, vn, dist) | ||
# TODO(mhauru) This should probably be call a function called setindex_internal! | ||
vi = BangBang.setindex!!(vi, f(r), vn) | ||
else | ||
# Otherwise we just extract it. | ||
r = vi[vn, dist] | ||
end | ||
else | ||
r = init(rng, dist, sampler) | ||
if istrans(vi) | ||
f = to_linked_internal_transform(vi, vn, dist) | ||
vi = push!!(vi, vn, f(r), dist) | ||
# By default `push!!` sets the transformed flag to `false`. | ||
vi = settrans!!(vi, true, vn) | ||
else | ||
vi = push!!(vi, vn, r, dist) | ||
end | ||
end | ||
|
||
# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. | ||
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) | ||
vi = accumulate_assume!!(vi, r, logjac, vn, dist) | ||
return r, vi | ||
throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it time to start a HISTORY.md entry? Might be easier to do it here were you can cross-check against what's being removed, rather than once everything is in breaking in a huge diff.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, good idea to do it in this PR. I'll write one up later and ping you again