Skip to content

Replace PrefixContext with model.prefix #1011

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 20 commits into
base: py/remove-samplingcontext
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 34 additions & 21 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Part of the API of DynamicPPL is defined in the more lightweight interface packa

A core component of DynamicPPL is the [`@model`](@ref) macro.
It can be used to define probabilistic models in an intuitive way by specifying random variables and their distributions with `~` statements.
These statements are rewritten by `@model` as calls of [internal functions](@ref model_internal) for sampling the variables and computing their log densities.
These statements are rewritten by `@model` as calls of internal functions for sampling the variables and computing their log densities.

```@docs
@model
Expand Down Expand Up @@ -344,6 +344,13 @@ Base.empty!
SimpleVarInfo
```

### Tilde-pipeline

```@docs
tilde_assume!!
tilde_observe!!
```

### Accumulators

The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators.
Expand Down Expand Up @@ -450,33 +457,45 @@ AbstractPPL.evaluate!!

This method mutates the `varinfo` used for execution.
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:

```@docs
DynamicPPL.evaluate_and_sample!!
```
If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this.

The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
Contexts are subtypes of `AbstractPPL.AbstractContext`.

```@docs
SamplingContext
DefaultContext
PrefixContext
ConditionContext
InitContext
```

### Samplers
### VarInfo initialisation

The function `init!!` is used to initialise, or overwrite, values in a VarInfo.
It is really a thin wrapper around using `evaluate!!` with an `InitContext`.

```@docs
DynamicPPL.init!!
```

In DynamicPPL two samplers are defined that are used to initialize unobserved random variables:
[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution.
To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained.
There are three concrete strategies provided in DynamicPPL:

```@docs
SampleFromPrior
SampleFromUniform
PriorInit
UniformInit
ParamsInit
```

Additionally, a generic sampler for inference is implemented.
If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method.

```@docs
DynamicPPL.AbstractInitStrategy
DynamicPPL.init
```

### Samplers

In DynamicPPL a generic sampler for inference is implemented.

```@docs
Sampler
Expand All @@ -487,7 +506,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu
```@docs
DynamicPPL.initialstep
DynamicPPL.loadstate
DynamicPPL.initialsampler
DynamicPPL.init_strategy
```

Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.
Expand All @@ -502,9 +521,3 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va
DynamicPPL.Experimental.determine_suitable_varinfo
DynamicPPL.Experimental.is_suitable_varinfo
```

### [Model-Internal Functions](@id model_internal)

```@docs
tilde_assume
```
2 changes: 0 additions & 2 deletions ext/DynamicPPLEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ else
using ..EnzymeCore
end

@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true

# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme
# only checks whether such a method exists, and never runs it.
@inline EnzymeCore.EnzymeRules.inactive_noinl(::typeof(DynamicPPL.istrans), args...) =
Expand Down
15 changes: 5 additions & 10 deletions ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,17 @@ end
function DynamicPPL.Experimental._determine_varinfo_jet(
model::DynamicPPL.Model; only_ddpl::Bool=true
)
# Use SamplingContext to test type stability.
sampling_model = DynamicPPL.contextualize(
model, DynamicPPL.SamplingContext(model.context)
)

# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(sampling_model)
varinfo = DynamicPPL.typed_varinfo(model)

# Let's make sure that both evaluation and sampling doesn't result in type errors.
# Let's make sure that evaluation doesn't result in type errors.
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
sampling_model, varinfo; only_ddpl
model, varinfo; only_ddpl
)

if !issuccess
# Useful information for debugging.
@debug "Evaluaton with typed varinfo failed with the following issues:"
@debug "Evaluation with typed varinfo failed with the following issues:"
@debug result
end

Expand All @@ -46,7 +41,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(sampling_model)
DynamicPPL.untyped_varinfo(model)
end
end

Expand Down
38 changes: 27 additions & 11 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ end

function _check_varname_indexing(c::MCMCChains.Chains)
return DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using `VarName`s.")
error("This `Chains` object does not support indexing using `VarName`s.")
end

function DynamicPPL.getindex_varname(
Expand All @@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
return keys(c.info.varname_to_symbol)
end

function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx)
_check_varname_indexing(c)
d = Dict{DynamicPPL.VarName,Any}()
for vn in DynamicPPL.varnames(c)
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
end
return d
end

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Expand Down Expand Up @@ -114,9 +123,15 @@ function DynamicPPL.predict(

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
predictive_samples = map(iters) do (sample_idx, chain_idx)
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))

# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`
_, varinfo = DynamicPPL.init!!(
rng,
model,
varinfo,
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
)
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
varname_vals = mapreduce(
collect,
Expand Down Expand Up @@ -248,13 +263,14 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
# Update the varinfo with the current sample and make variables not present in `chain`
# to be sampled.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to the `model`.
model(deepcopy(varinfo))
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`, and
# return the model's retval.
retval, _ = DynamicPPL.init!!(
model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit())
)
retval
end
end

Expand Down
19 changes: 12 additions & 7 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,18 +97,21 @@ export AbstractVarInfo,
values_as_in_model,
# Samplers
Sampler,
SampleFromPrior,
SampleFromUniform,
# LogDensityFunction
LogDensityFunction,
# Contexts
contextualize,
SamplingContext,
DefaultContext,
PrefixContext,
ConditionContext,
assume,
tilde_assume,
# Tilde pipeline
tilde_assume!!,
tilde_observe!!,
# Initialisation
InitContext,
AbstractInitStrategy,
PriorInit,
UniformInit,
ParamsInit,
# Pseudo distributions
NamedDist,
NoDist,
Expand Down Expand Up @@ -170,11 +173,13 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
# Necessary forward declarations
include("utils.jl")
include("chains.jl")
include("contexts.jl")
include("contexts/init.jl")
include("model.jl")
include("prefix.jl")
include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("submodel.jl")
include("varnamedvector.jl")
include("accumulators.jl")
Expand Down
47 changes: 25 additions & 22 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,14 @@ evaluates to a `VarName`, and this will be used in the subsequent checks.
If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be
used in its place.
"""
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr))
function isassumption(expr::Union{Expr,Symbol}, left_vn=make_varname_expression(expr))
@gensym vn
return quote
if $(DynamicPPL.contextual_isassumption)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)
# TODO(penelopeysm): This re-prefixing seems a bit wasteful. I'd really like
# the whole `isassumption` thing to be simplified, though, so I'll
# leave it till later.
$vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix)
if $(DynamicPPL.contextual_isassumption)(__model__.context, $vn)
# Considered an assumption by `__model__.context` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
Expand All @@ -78,8 +81,8 @@ function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr)
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
if !($(DynamicPPL.inargnames)($left_vn, __model__)) ||
$(DynamicPPL.inmissings)($left_vn, __model__)
true
else
$(maybe_view(expr)) === missing
Expand All @@ -99,7 +102,7 @@ isassumption(expr) = :(false)

Return `true` if `vn` is considered an assumption by `context`.
"""
function contextual_isassumption(context::AbstractContext, vn)
function contextual_isassumption(context::AbstractContext, vn::VarName)
if hasconditioned_nested(context, vn)
val = getconditioned_nested(context, vn)
# TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler?
Expand All @@ -115,9 +118,7 @@ end

isfixed(expr, vn) = false
function isfixed(::Union{Symbol,Expr}, vn)
return :($(DynamicPPL.contextual_isfixed)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
))
return :($(DynamicPPL.contextual_isfixed)(__model__.context, $vn))
end

"""
Expand Down Expand Up @@ -413,7 +414,9 @@ function generate_assign(left, right)
return quote
$right_val = $right
if $(DynamicPPL.is_extracting_values)(__varinfo__)
$vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left)))
$vn = $(DynamicPPL.maybe_prefix)(
$(make_varname_expression(left)), __model__.prefix
)
__varinfo__ = $(map_accumulator!!)(
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
)
Expand Down Expand Up @@ -448,24 +451,23 @@ function generate_tilde(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn isassumption value dist
@gensym left_vn vn isassumption value dist

return quote
$dist = $right
$vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist)
$isassumption = $(DynamicPPL.isassumption(left, vn))
$left_vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist)
$vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix)
$isassumption = $(DynamicPPL.isassumption(left, left_vn))
if $(DynamicPPL.isfixed(left, vn))
$left = $(DynamicPPL.getfixed_nested)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)
$left = $(DynamicPPL.getfixed_nested)(__model__.context, $vn)
elseif $isassumption
$(generate_tilde_assume(left, dist, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left = $(DynamicPPL.getconditioned_nested)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)
# If `left_vn` is not in `argnames`, we need to make sure that the variable is defined.
# (Note: we use the unprefixed `left_vn` here rather than `vn` which will have had
# prefixes applied!)
if !$(DynamicPPL.inargnames)($left_vn, __model__)
$left = $(DynamicPPL.getconditioned_nested)(__model__.context, $vn)
end

$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
Expand Down Expand Up @@ -495,6 +497,7 @@ function generate_tilde_assume(left, right, vn)
return quote
$value, __varinfo__ = $(DynamicPPL.tilde_assume!!)(
__model__.context,
__model__.prefix,
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
__varinfo__,
)
Expand Down
Loading
Loading