Skip to content

Commit 9e8eb05

Browse files
committed
WIP: InitContext
1 parent f1d5f20 commit 9e8eb05

12 files changed

+130
-426
lines changed

src/DynamicPPL.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,14 @@ export AbstractVarInfo,
9797
values_as_in_model,
9898
# Samplers
9999
Sampler,
100-
SampleFromPrior,
101-
SampleFromUniform,
100+
# Initialisation strategies
101+
PriorInit,
102+
UniformInit,
103+
ParamsInit,
102104
# LogDensityFunction
103105
LogDensityFunction,
104106
# Contexts
105107
contextualize,
106-
SamplingContext,
107108
DefaultContext,
108109
PrefixContext,
109110
ConditionContext,

src/context_implementations.jl

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,4 @@
11
# assume
2-
"""
3-
tilde_assume(context::SamplingContext, right, vn, vi)
4-
5-
Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
6-
accumulate the log probability, and return the sampled value with a context associated
7-
with a sampler.
8-
9-
Falls back to
10-
```julia
11-
tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
12-
```
13-
"""
14-
function tilde_assume(context::SamplingContext, right, vn, vi)
15-
return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
16-
end
17-
182
function tilde_assume(context::AbstractContext, args...)
193
return tilde_assume(childcontext(context), args...)
204
end
@@ -71,17 +55,6 @@ function tilde_assume!!(context, right, vn, vi)
7155
end
7256

7357
# observe
74-
"""
75-
tilde_observe!!(context::SamplingContext, right, left, vi)
76-
77-
Handle observed constants with a `context` associated with a sampler.
78-
79-
Falls back to `tilde_observe!!(context.context, right, left, vi)`.
80-
"""
81-
function tilde_observe!!(context::SamplingContext, right, left, vn, vi)
82-
return tilde_observe!!(context.context, right, left, vn, vi)
83-
end
84-
8558
function tilde_observe!!(context::AbstractContext, right, left, vn, vi)
8659
return tilde_observe!!(childcontext(context), right, left, vn, vi)
8760
end
@@ -127,46 +100,3 @@ function assume(dist::Distribution, vn::VarName, vi)
127100
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
128101
return x, vi
129102
end
130-
131-
# TODO: Remove this thing.
132-
# SampleFromPrior and SampleFromUniform
133-
function assume(
134-
rng::Random.AbstractRNG,
135-
sampler::Union{SampleFromPrior,SampleFromUniform},
136-
dist::Distribution,
137-
vn::VarName,
138-
vi::VarInfoOrThreadSafeVarInfo,
139-
)
140-
if haskey(vi, vn)
141-
# Always overwrite the parameters with new ones for `SampleFromUniform`.
142-
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
143-
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
144-
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
145-
# if that's okay.
146-
unset_flag!(vi, vn, "del", true)
147-
r = init(rng, dist, sampler)
148-
f = to_maybe_linked_internal_transform(vi, vn, dist)
149-
# TODO(mhauru) This should probably be call a function called setindex_internal!
150-
vi = BangBang.setindex!!(vi, f(r), vn)
151-
setorder!(vi, vn, get_num_produce(vi))
152-
else
153-
# Otherwise we just extract it.
154-
r = vi[vn, dist]
155-
end
156-
else
157-
r = init(rng, dist, sampler)
158-
if istrans(vi)
159-
f = to_linked_internal_transform(vi, vn, dist)
160-
vi = push!!(vi, vn, f(r), dist)
161-
# By default `push!!` sets the transformed flag to `false`.
162-
vi = settrans!!(vi, true, vn)
163-
else
164-
vi = push!!(vi, vn, r, dist)
165-
end
166-
end
167-
168-
# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
169-
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
170-
vi = accumulate_assume!!(vi, r, -logjac, vn, dist)
171-
return r, vi
172-
end

src/contexts.jl

Lines changed: 1 addition & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ effectively updating the child context.
4747
```jldoctest
4848
julia> using DynamicPPL: DynamicTransformationContext
4949
50-
julia> ctx = SamplingContext();
50+
julia> ctx = ConditionContext((; a = 1);
5151
5252
julia> DynamicPPL.childcontext(ctx)
5353
DefaultContext()
@@ -121,73 +121,6 @@ setleafcontext(::IsLeaf, ::IsParent, left, right) = right
121121
setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right
122122

123123
# Contexts
124-
"""
125-
SamplingContext(
126-
[rng::Random.AbstractRNG=Random.default_rng()],
127-
[sampler::AbstractSampler=SampleFromPrior()],
128-
[context::AbstractContext=DefaultContext()],
129-
)
130-
131-
Create a context that allows you to sample parameters with the `sampler` when running the model.
132-
The `context` determines how the returned log density is computed when running the model.
133-
134-
See also: [`DefaultContext`](@ref)
135-
"""
136-
struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext
137-
rng::R
138-
sampler::S
139-
context::C
140-
end
141-
142-
function SamplingContext(
143-
rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior()
144-
)
145-
return SamplingContext(rng, sampler, DefaultContext())
146-
end
147-
148-
function SamplingContext(
149-
sampler::AbstractSampler, context::AbstractContext=DefaultContext()
150-
)
151-
return SamplingContext(Random.default_rng(), sampler, context)
152-
end
153-
154-
function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext)
155-
return SamplingContext(rng, SampleFromPrior(), context)
156-
end
157-
158-
function SamplingContext(context::AbstractContext)
159-
return SamplingContext(Random.default_rng(), SampleFromPrior(), context)
160-
end
161-
162-
NodeTrait(context::SamplingContext) = IsParent()
163-
childcontext(context::SamplingContext) = context.context
164-
function setchildcontext(parent::SamplingContext, child)
165-
return SamplingContext(parent.rng, parent.sampler, child)
166-
end
167-
168-
"""
169-
hassampler(context)
170-
171-
Return `true` if `context` has a sampler.
172-
"""
173-
hassampler(::SamplingContext) = true
174-
hassampler(context::AbstractContext) = hassampler(NodeTrait(context), context)
175-
hassampler(::IsLeaf, context::AbstractContext) = false
176-
hassampler(::IsParent, context::AbstractContext) = hassampler(childcontext(context))
177-
178-
"""
179-
getsampler(context)
180-
181-
Return the sampler of the context `context`.
182-
183-
This will traverse the context tree until it reaches the first [`SamplingContext`](@ref),
184-
at which point it will return the sampler of that context.
185-
"""
186-
getsampler(context::SamplingContext) = context.sampler
187-
getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
188-
getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
189-
getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")
190-
191124
"""
192125
struct DefaultContext <: AbstractContext end
193126

src/debug_utils.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -438,9 +438,8 @@ function check_model_and_trace(
438438
kwargs...,
439439
)
440440
# Execute the model with the debug context.
441-
debug_context = DebugContext(
442-
SamplingContext(rng, model.context); error_on_failure=error_on_failure, kwargs...
443-
)
441+
new_context = setleafcontext(model.context, InitContext(rng, Prior()))
442+
debug_context = DebugContext(new_context; error_on_failure=error_on_failure, kwargs...)
444443
debug_model = DynamicPPL.contextualize(model, debug_context)
445444

446445
# Perform checks before evaluating the model.

src/extract_priors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ function extract_priors(rng::Random.AbstractRNG, model::Model)
116116
# workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
117117
# can't push new variables without knowing the num_produce. Remove this when possible.
118118
varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator()))
119-
varinfo = last(evaluate_and_sample!!(rng, model, varinfo))
119+
varinfo = last(init!!(rng, model, varinfo))
120120
return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors
121121
end
122122

src/model.jl

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ end
850850
# ^ Weird Documenter.jl bug means that we have to write the two above separately
851851
# as it can only detect the `function`-less syntax.
852852
function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo())
853-
return first(evaluate_and_sample!!(rng, model, varinfo))
853+
return first(init!!(rng, model, varinfo))
854854
end
855855

856856
"""
@@ -864,29 +864,35 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
864864
end
865865

866866
"""
867-
evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler])
868-
869-
Evaluate the `model` with the given `varinfo`, but perform sampling during the
870-
evaluation using the given `sampler` by wrapping the model's context in a
871-
`SamplingContext`.
867+
init!!(
868+
[rng::Random.AbstractRNG, ]
869+
model::Model,
870+
varinfo::AbstractVarInfo,
871+
[init_strategy::AbstractInitStrategy=PriorInit()]
872+
)
872873
873-
If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref).
874+
Evaluate the `model` and replace the values of the model's random variables
875+
in the given `varinfo` with new values, using a specified initialisation strategy.
876+
If the values in `varinfo` are not set, they will be added.
877+
using a specified initialisation strategy. If `init_strategy` is not provided,
878+
defaults to PriorInit().
874879
875880
Returns a tuple of the model's return value, plus the updated `varinfo` object.
876881
"""
877-
function evaluate_and_sample!!(
882+
function init!!(
878883
rng::Random.AbstractRNG,
879884
model::Model,
880885
varinfo::AbstractVarInfo,
881-
sampler::AbstractSampler=SampleFromPrior(),
886+
init_strategy::AbstractInitStrategy=PriorInit(),
882887
)
883-
sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context))
884-
return evaluate!!(sampling_model, varinfo)
888+
new_context = setleafcontext(model.context, InitContext(rng, init_strategy))
889+
new_model = contextualize(model, new_context)
890+
return evaluate!!(new_model, varinfo)
885891
end
886-
function evaluate_and_sample!!(
887-
model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior()
892+
function init!!(
893+
model::Model, varinfo::AbstractVarInfo, init_strategy::AbstractInitStrategy=PriorInit()
888894
)
889-
return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler)
895+
return init!!(Random.default_rng(), model, varinfo, init_strategy)
890896
end
891897

892898
"""
@@ -1049,11 +1055,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
10491055
Generate a sample of type `T` from the prior distribution of the `model`.
10501056
"""
10511057
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
1052-
x = last(
1053-
evaluate_and_sample!!(
1054-
rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())
1055-
),
1056-
)
1058+
x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())))
10571059
return values_as(x, T)
10581060
end
10591061

@@ -1280,3 +1282,38 @@ end
12801282
function returned(model::Model, values, keys)
12811283
return returned(model, NamedTuple{keys}(values))
12821284
end
1285+
1286+
"""
1287+
prefix(model::Model, x::VarName)
1288+
prefix(model::Model, x::Val{sym})
1289+
prefix(model::Model, x::Any)
1290+
1291+
Return `model` but with all random variables prefixed by `x`, where `x` is either:
1292+
- a `VarName` (e.g. `@varname(a)`),
1293+
- a `Val{sym}` (e.g. `Val(:a)`), or
1294+
- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that
1295+
this will introduce runtime overheads so is not recommended unless absolutely
1296+
necessary.
1297+
1298+
# Examples
1299+
1300+
```jldoctest
1301+
julia> using DynamicPPL: prefix
1302+
1303+
julia> @model demo() = x ~ Dirac(1)
1304+
demo (generic function with 2 methods)
1305+
1306+
julia> rand(prefix(demo(), @varname(my_prefix)))
1307+
(var"my_prefix.x" = 1,)
1308+
1309+
julia> rand(prefix(demo(), Val(:my_prefix)))
1310+
(var"my_prefix.x" = 1,)
1311+
```
1312+
"""
1313+
prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context))
1314+
function prefix(model::Model, x::Val{sym}) where {sym}
1315+
return contextualize(model, PrefixContext(VarName{sym}(), model.context))
1316+
end
1317+
function prefix(model::Model, x)
1318+
return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context))
1319+
end

0 commit comments

Comments
 (0)