Skip to content

Commit 5a5a4e9

Browse files
committed
Remove SamplingContext for good
1 parent 2ca382e commit 5a5a4e9

16 files changed

+31
-389
lines changed

docs/src/api.md

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -447,12 +447,12 @@ AbstractPPL.evaluate!!
447447

448448
This method mutates the `varinfo` used for execution.
449449
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
450+
If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this.
450451

451452
The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
452453
Contexts are subtypes of `AbstractPPL.AbstractContext`.
453454

454455
```@docs
455-
SamplingContext
456456
DefaultContext
457457
PrefixContext
458458
ConditionContext
@@ -486,15 +486,7 @@ DynamicPPL.init
486486

487487
### Samplers
488488

489-
In DynamicPPL two samplers are defined that are used to initialize unobserved random variables:
490-
[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution.
491-
492-
```@docs
493-
SampleFromPrior
494-
SampleFromUniform
495-
```
496-
497-
Additionally, a generic sampler for inference is implemented.
489+
In DynamicPPL a generic sampler for inference is implemented.
498490

499491
```@docs
500492
Sampler

ext/DynamicPPLEnzymeCoreExt.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ else
88
using ..EnzymeCore
99
end
1010

11-
@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true
12-
1311
# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme
1412
# only checks whether such a method exists, and never runs it.
1513
@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.istrans), args...) = nothing

src/DynamicPPL.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,10 @@ export AbstractVarInfo,
9696
values_as_in_model,
9797
# Samplers
9898
Sampler,
99-
SampleFromPrior,
100-
SampleFromUniform,
10199
# LogDensityFunction
102100
LogDensityFunction,
103101
# Contexts
104102
contextualize,
105-
SamplingContext,
106103
DefaultContext,
107104
PrefixContext,
108105
ConditionContext,

src/context_implementations.jl

Lines changed: 5 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,14 @@
11
# assume
2-
"""
3-
tilde_assume(context::SamplingContext, right, vn, vi)
4-
5-
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
6-
accumulate the log probability, and return the sampled value with a context associated
7-
with a sampler.
8-
9-
Falls back to
10-
```julia
11-
tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
12-
```
13-
"""
14-
function tilde_assume(context::SamplingContext, right, vn, vi)
15-
return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
16-
end
17-
182
function tilde_assume(context::AbstractContext, args...)
193
return tilde_assume(childcontext(context), args...)
204
end
215
function tilde_assume(::DefaultContext, right, vn, vi)
22-
return assume(right, vn, vi)
23-
end
24-
25-
function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...)
26-
return tilde_assume(rng, childcontext(context), args...)
27-
end
28-
function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi)
29-
return assume(rng, sampler, right, vn, vi)
30-
end
31-
function tilde_assume(::DefaultContext, sampler, right, vn, vi)
32-
# same as above but no rng
33-
return assume(Random.default_rng(), sampler, right, vn, vi)
6+
y = getindex_internal(vi, vn)
7+
f = from_maybe_linked_internal_transform(vi, vn, right)
8+
x, inv_logjac = with_logabsdet_jacobian(f, y)
9+
vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right)
10+
return x, vi
3411
end
35-
3612
function tilde_assume(context::PrefixContext, right, vn, vi)
3713
# Note that we can't use something like this here:
3814
# new_vn = prefix(context, vn)
@@ -46,12 +22,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi)
4622
new_vn, new_context = prefix_and_strip_contexts(context, vn)
4723
return tilde_assume(new_context, right, new_vn, vi)
4824
end
49-
function tilde_assume(
50-
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi
51-
)
52-
new_vn, new_context = prefix_and_strip_contexts(context, vn)
53-
return tilde_assume(rng, new_context, sampler, right, new_vn, vi)
54-
end
5525

5626
"""
5727
tilde_assume!!(context, right, vn, vi)
@@ -71,17 +41,6 @@ function tilde_assume!!(context, right, vn, vi)
7141
end
7242

7343
# observe
74-
"""
75-
tilde_observe!!(context::SamplingContext, right, left, vi)
76-
77-
Handle observed constants with a `context` associated with a sampler.
78-
79-
Falls back to `tilde_observe!!(context.context, right, left, vi)`.
80-
"""
81-
function tilde_observe!!(context::SamplingContext, right, left, vn, vi)
82-
return tilde_observe!!(context.context, right, left, vn, vi)
83-
end
84-
8544
function tilde_observe!!(context::AbstractContext, right, left, vn, vi)
8645
return tilde_observe!!(childcontext(context), right, left, vn, vi)
8746
end
@@ -114,58 +73,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi)
11473
vi = accumulate_observe!!(vi, right, left, vn)
11574
return left, vi
11675
end
117-
118-
function assume(::Random.AbstractRNG, spl::Sampler, dist)
119-
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
120-
end
121-
122-
# fallback without sampler
123-
function assume(dist::Distribution, vn::VarName, vi)
124-
y = getindex_internal(vi, vn)
125-
f = from_maybe_linked_internal_transform(vi, vn, dist)
126-
x, inv_logjac = with_logabsdet_jacobian(f, y)
127-
vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist)
128-
return x, vi
129-
end
130-
131-
# TODO: Remove this thing.
132-
# SampleFromPrior and SampleFromUniform
133-
function assume(
134-
rng::Random.AbstractRNG,
135-
sampler::Union{SampleFromPrior,SampleFromUniform},
136-
dist::Distribution,
137-
vn::VarName,
138-
vi::VarInfoOrThreadSafeVarInfo,
139-
)
140-
if haskey(vi, vn)
141-
# Always overwrite the parameters with new ones for `SampleFromUniform`.
142-
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
143-
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
144-
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
145-
# if that's okay.
146-
unset_flag!(vi, vn, "del", true)
147-
r = init(rng, dist, sampler)
148-
f = to_maybe_linked_internal_transform(vi, vn, dist)
149-
# TODO(mhauru) This should probably be call a function called setindex_internal!
150-
vi = BangBang.setindex!!(vi, f(r), vn)
151-
else
152-
# Otherwise we just extract it.
153-
r = vi[vn, dist]
154-
end
155-
else
156-
r = init(rng, dist, sampler)
157-
if istrans(vi)
158-
f = to_linked_internal_transform(vi, vn, dist)
159-
vi = push!!(vi, vn, f(r), dist)
160-
# By default `push!!` sets the transformed flag to `false`.
161-
vi = settrans!!(vi, true, vn)
162-
else
163-
vi = push!!(vi, vn, r, dist)
164-
end
165-
end
166-
167-
# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
168-
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
169-
vi = accumulate_assume!!(vi, r, logjac, vn, dist)
170-
return r, vi
171-
end

src/contexts.jl

Lines changed: 1 addition & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ effectively updating the child context.
4747
```jldoctest
4848
julia> using DynamicPPL: DynamicTransformationContext
4949
50-
julia> ctx = SamplingContext();
50+
julia> ctx = ConditionContext((; a = 1));
5151
5252
julia> DynamicPPL.childcontext(ctx)
5353
DefaultContext()
@@ -121,73 +121,6 @@ setleafcontext(::IsLeaf, ::IsParent, left, right) = right
121121
setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right
122122

123123
# Contexts
124-
"""
125-
SamplingContext(
126-
[rng::Random.AbstractRNG=Random.default_rng()],
127-
[sampler::AbstractSampler=SampleFromPrior()],
128-
[context::AbstractContext=DefaultContext()],
129-
)
130-
131-
Create a context that allows you to sample parameters with the `sampler` when running the model.
132-
The `context` determines how the returned log density is computed when running the model.
133-
134-
See also: [`DefaultContext`](@ref)
135-
"""
136-
struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext
137-
rng::R
138-
sampler::S
139-
context::C
140-
end
141-
142-
function SamplingContext(
143-
rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior()
144-
)
145-
return SamplingContext(rng, sampler, DefaultContext())
146-
end
147-
148-
function SamplingContext(
149-
sampler::AbstractSampler, context::AbstractContext=DefaultContext()
150-
)
151-
return SamplingContext(Random.default_rng(), sampler, context)
152-
end
153-
154-
function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext)
155-
return SamplingContext(rng, SampleFromPrior(), context)
156-
end
157-
158-
function SamplingContext(context::AbstractContext)
159-
return SamplingContext(Random.default_rng(), SampleFromPrior(), context)
160-
end
161-
162-
NodeTrait(context::SamplingContext) = IsParent()
163-
childcontext(context::SamplingContext) = context.context
164-
function setchildcontext(parent::SamplingContext, child)
165-
return SamplingContext(parent.rng, parent.sampler, child)
166-
end
167-
168-
"""
169-
hassampler(context)
170-
171-
Return `true` if `context` has a sampler.
172-
"""
173-
hassampler(::SamplingContext) = true
174-
hassampler(context::AbstractContext) = hassampler(NodeTrait(context), context)
175-
hassampler(::IsLeaf, context::AbstractContext) = false
176-
hassampler(::IsParent, context::AbstractContext) = hassampler(childcontext(context))
177-
178-
"""
179-
getsampler(context)
180-
181-
Return the sampler of the context `context`.
182-
183-
This will traverse the context tree until it reaches the first [`SamplingContext`](@ref),
184-
at which point it will return the sampler of that context.
185-
"""
186-
getsampler(context::SamplingContext) = context.sampler
187-
getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
188-
getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
189-
getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")
190-
191124
"""
192125
struct DefaultContext <: AbstractContext end
193126

src/debug_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ and checking if the model is consistent across runs.
485485
function has_static_constraints(
486486
rng::Random.AbstractRNG, model::Model; num_evals::Int=5, error_on_failure::Bool=false
487487
)
488-
new_model = DynamicPPL.contextualize(model, SamplingContext(rng, SampleFromPrior()))
488+
new_model = DynamicPPL.contextualize(model, InitContext(rng))
489489
results = map(1:num_evals) do _
490490
check_model_and_trace(new_model, VarInfo(); error_on_failure=error_on_failure)
491491
end

src/sampler.jl

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,3 @@
1-
# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler`
2-
# That would let us use all defaults for Sampler, combine it with other samplers etc.
3-
"""
4-
SampleFromUniform
5-
6-
Sampling algorithm that samples unobserved random variables from a uniform distribution.
7-
8-
# References
9-
10-
[Stan reference manual](https://mc-stan.org/docs/2_28/reference-manual/initialization.html#random-initial-values)
11-
"""
12-
struct SampleFromUniform <: AbstractSampler end
13-
14-
"""
15-
SampleFromPrior
16-
17-
Sampling algorithm that samples unobserved random variables from their prior distribution.
18-
"""
19-
struct SampleFromPrior <: AbstractSampler end
20-
21-
# Initializations.
22-
init(rng, dist, ::SampleFromPrior) = rand(rng, dist)
23-
function init(rng, dist, ::SampleFromUniform)
24-
return istransformable(dist) ? inittrans(rng, dist) : rand(rng, dist)
25-
end
26-
27-
init(rng, dist, ::SampleFromPrior, n::Int) = rand(rng, dist, n)
28-
function init(rng, dist, ::SampleFromUniform, n::Int)
29-
return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n)
30-
end
31-
321
# TODO(mhauru) Could we get rid of Sampler now that it's just a wrapper around `alg`?
332
# (Selector has been removed).
343
"""
@@ -49,20 +18,6 @@ struct Sampler{T} <: AbstractSampler
4918
alg::T
5019
end
5120

52-
# AbstractMCMC interface for SampleFromUniform and SampleFromPrior
53-
function AbstractMCMC.step(
54-
rng::Random.AbstractRNG,
55-
model::Model,
56-
sampler::Union{SampleFromUniform,SampleFromPrior},
57-
state=nothing;
58-
kwargs...,
59-
)
60-
vi = VarInfo()
61-
strategy = sampler isa SampleFromPrior ? InitFromPrior() : InitFromUniform()
62-
_, new_vi = DynamicPPL.init!!(rng, model, vi, strategy)
63-
return new_vi, nothing
64-
end
65-
6621
"""
6722
default_varinfo(rng, model, sampler)
6823

src/simple_varinfo.jl

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -466,25 +466,6 @@ function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
466466
return SimpleVarInfo(values, accs, transformation)
467467
end
468468

469-
# Context implementations
470-
# NOTE: Evaluations, i.e. those without `rng` are shared with other
471-
# implementations of `AbstractVarInfo`.
472-
function assume(
473-
rng::Random.AbstractRNG,
474-
sampler::Union{SampleFromPrior,SampleFromUniform},
475-
dist::Distribution,
476-
vn::VarName,
477-
vi::SimpleOrThreadSafeSimple,
478-
)
479-
value = init(rng, dist, sampler)
480-
# Transform if we're working in unconstrained space.
481-
f = to_maybe_linked_internal_transform(vi, vn, dist)
482-
value_raw, logjac = with_logabsdet_jacobian(f, value)
483-
vi = BangBang.push!!(vi, vn, value_raw, dist)
484-
vi = accumulate_assume!!(vi, value, logjac, vn, dist)
485-
return value, vi
486-
end
487-
488469
function settrans!!(vi::SimpleVarInfo, trans)
489470
return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())
490471
end

0 commit comments

Comments
 (0)