Skip to content

Commit 991e825

Browse files
penelopeysmmhauru
andauthored
InitContext, part 3 - Introduce InitContext (#981)
* Implement InitContext * Fix loading order of modules; move `prefix(::Model)` to model.jl * Add tests for InitContext behaviour * inline `rand(::Distributions.Uniform)` Note that, apart from being simpler code, Distributions.Uniform also doesn't allow the lower and upper bounds to be exactly equal (but we might like to keep that option open in DynamicPPL, e.g. if the user wants to initialise all values to the same value in linked space). * Document * Add a test to check that `init!!` doesn't change linking * Fix `push!` for VarNamedVector This should have been changed in #940, but slipped through as the file wasn't listed as one of the changed files. * Add some line breaks Co-authored-by: Markus Hauru <[email protected]> * Add the option of no fallback for ParamsInit * Improve docstrings * typo * `p.default` -> `p.fallback` * Rename `{Prior,Uniform,Params}Init` -> `InitFrom{Prior,Uniform,Params}` --------- Co-authored-by: Markus Hauru <[email protected]>
1 parent 7b55aa3 commit 991e825

File tree

7 files changed

+547
-40
lines changed

7 files changed

+547
-40
lines changed

docs/src/api.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,27 @@ SamplingContext
463463
DefaultContext
464464
PrefixContext
465465
ConditionContext
466+
InitContext
467+
```
468+
469+
### VarInfo initialisation
470+
471+
`InitContext` is used to initialise, or overwrite, values in a VarInfo.
472+
473+
To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained.
474+
There are three concrete strategies provided in DynamicPPL:
475+
476+
```@docs
477+
InitFromPrior
478+
InitFromUniform
479+
InitFromParams
480+
```
481+
482+
If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method.
483+
484+
```@docs
485+
DynamicPPL.AbstractInitStrategy
486+
DynamicPPL.init
466487
```
467488

468489
### Samplers

src/DynamicPPL.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,12 @@ export AbstractVarInfo,
108108
ConditionContext,
109109
assume,
110110
tilde_assume,
111+
# Initialisation
112+
InitContext,
113+
AbstractInitStrategy,
114+
InitFromPrior,
115+
InitFromUniform,
116+
InitFromParams,
111117
# Pseudo distributions
112118
NamedDist,
113119
NoDist,
@@ -169,11 +175,12 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
169175
# Necessary forward declarations
170176
include("utils.jl")
171177
include("chains.jl")
178+
include("contexts.jl")
179+
include("contexts/init.jl")
172180
include("model.jl")
173181
include("sampler.jl")
174182
include("varname.jl")
175183
include("distribution_wrappers.jl")
176-
include("contexts.jl")
177184
include("submodel.jl")
178185
include("varnamedvector.jl")
179186
include("accumulators.jl")

src/contexts.jl

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -280,41 +280,6 @@ function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName
280280
return vn, setchildcontext(ctx, new_ctx)
281281
end
282282

283-
"""
284-
prefix(model::Model, x::VarName)
285-
prefix(model::Model, x::Val{sym})
286-
prefix(model::Model, x::Any)
287-
288-
Return `model` but with all random variables prefixed by `x`, where `x` is either:
289-
- a `VarName` (e.g. `@varname(a)`),
290-
- a `Val{sym}` (e.g. `Val(:a)`), or
291-
- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that
292-
this will introduce runtime overheads so is not recommended unless absolutely
293-
necessary.
294-
295-
# Examples
296-
297-
```jldoctest
298-
julia> using DynamicPPL: prefix
299-
300-
julia> @model demo() = x ~ Dirac(1)
301-
demo (generic function with 2 methods)
302-
303-
julia> rand(prefix(demo(), @varname(my_prefix)))
304-
(var"my_prefix.x" = 1,)
305-
306-
julia> rand(prefix(demo(), Val(:my_prefix)))
307-
(var"my_prefix.x" = 1,)
308-
```
309-
"""
310-
prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context))
311-
function prefix(model::Model, x::Val{sym}) where {sym}
312-
return contextualize(model, PrefixContext(VarName{sym}(), model.context))
313-
end
314-
function prefix(model::Model, x)
315-
return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context))
316-
end
317-
318283
"""
319284
320285
ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext}

src/contexts/init.jl

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
"""
    AbstractInitStrategy

Supertype of all initialisation strategies, i.e. the different ways in which
new values for the random variables of a model can be obtained (for instance
when populating a fresh VarInfo).

Every subtype of `AbstractInitStrategy` is required to provide a method of
[`DynamicPPL.init`](@ref).
"""
abstract type AbstractInitStrategy end
11+
"""
    init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy)

Generate a new value for the random variable `vn`, whose distribution is `dist`,
according to the initialisation `strategy`.

!!! warning "Return values must be unlinked"
    The values returned by `init` must always be in the untransformed space, i.e.,
    they must be within the support of the original distribution. That means that,
    for example, `init(rng, vn, dist, u::InitFromUniform)` will in general return
    values that are outside the range [u.lower, u.upper].
"""
function init end
24+
"""
    InitFromPrior()

Initialisation strategy that draws each new value directly from the prior
distribution of the corresponding random variable.
"""
struct InitFromPrior <: AbstractInitStrategy end

# The variable name plays no role here: the draw depends only on `rng` and `dist`.
function init(
    rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior
)
    draw = rand(rng, dist)
    return draw
end
34+
"""
    InitFromUniform()
    InitFromUniform(lower, upper)

Initialisation strategy that works in unconstrained space: the distribution of
the random variable is transformed to unconstrained space, a value is drawn
uniformly between `lower` and `upper` there, and that value is then mapped back
to the original space.

When called without arguments, the bounds are `(-2, 2)`, which mimics Stan's
default initialisation strategy.

`lower <= upper` is required; an `ArgumentError` is thrown otherwise. Equal
bounds are allowed.

# References

[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization)
"""
struct InitFromUniform{T<:AbstractFloat} <: AbstractInitStrategy
    lower::T
    upper::T
    function InitFromUniform(lower::T, upper::T) where {T<:AbstractFloat}
        if lower > upper
            throw(ArgumentError("`lower` must be less than or equal to `upper`"))
        end
        return new{T}(lower, upper)
    end
    function InitFromUniform()
        return InitFromUniform(-2.0, 2.0)
    end
end
# Draw uniformly in the unconstrained (linked) space of `dist` and map the draw
# back to the original space before returning it.
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFromUniform)
    to_linked = Bijectors.bijector(dist)
    from_linked = Bijectors.inverse(to_linked)
    linked_size = Bijectors.output_size(to_linked, size(dist))
    width = u.upper - u.lower
    linked_value = u.lower .+ (width .* rand(rng, linked_size...))
    value = from_linked(linked_value)
    # Unwrap 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398
    return value isa Array{<:Any,0} ? value[] : value
end
74+
"""
    InitFromParams(
        params::Union{AbstractDict{<:VarName},NamedTuple},
        fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
    )

Initialisation strategy that looks new values up in the supplied dictionary or
NamedTuple.

`fallback` controls what happens when a variable cannot be found in `params`,
or when its entry is `missing`:

- if `fallback` is an initialisation strategy, that strategy is used to obtain
  the value instead;
- if `fallback` is `nothing`, an error is thrown.

`fallback` defaults to `InitFromPrior()`.

!!! note
    The values in `params` must be provided in the space of the untransformed
    distribution.
"""
struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy
    params::P
    fallback::S
    # Dictionaries keyed by `VarName` are stored as they are.
    function InitFromParams(
        params::AbstractDict{<:VarName},
        fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior(),
    )
        D = typeof(params)
        F = typeof(fallback)
        return new{D,F}(params, fallback)
    end
    # NamedTuples are first converted with `to_varname_dict`.
    function InitFromParams(params::NamedTuple)
        return InitFromParams(params, InitFromPrior())
    end
    function InitFromParams(
        params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}
    )
        return InitFromParams(to_varname_dict(params), fallback)
    end
end
# Look `vn` up in `p.params`; defer to `p.fallback` (or error when there is no
# fallback) if the value is absent or given as `missing`.
function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams)
    # TODO(penelopeysm): It would be nice to do a check to make sure that all
    # of the parameters in `p.params` were actually used, and either warn or
    # error if they aren't. This is actually quite non-trivial though because
    # the structure of Dicts in particular can have arbitrary nesting.
    if hasvalue(p.params, vn, dist)
        value = getvalue(p.params, vn, dist)
        # TODO(penelopeysm): Since the value is user-supplied, maybe we could
        # also check here that its type / size matches the dist?
        value === missing || return value
        if p.fallback === nothing
            error("A `missing` value was provided for the variable `$(vn)`.")
        end
    elseif p.fallback === nothing
        error("No value was provided for the variable `$(vn)`.")
    end
    return init(rng, vn, dist, p.fallback)
end
131+
"""
    InitContext(
        [rng::Random.AbstractRNG=Random.default_rng()],
        [strategy::AbstractInitStrategy=InitFromPrior()],
    )

Leaf context signalling that new values for random variables are currently
being obtained through sampling, e.g. while initialising a fresh VarInfo.

Note that, if `leafcontext(model.context) isa InitContext`, then
`evaluate!!(model, varinfo)` will override all values in the VarInfo.
"""
struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext
    rng::R
    strategy::S
    function InitContext(
        rng::R, strategy::S
    ) where {R<:Random.AbstractRNG,S<:AbstractInitStrategy}
        return new{R,S}(rng, strategy)
    end
    # Convenience constructors filling in the defaults for whichever of `rng`
    # and `strategy` is left out.
    InitContext(rng::Random.AbstractRNG) = InitContext(rng, InitFromPrior())
    function InitContext(strategy::AbstractInitStrategy)
        return InitContext(Random.default_rng(), strategy)
    end
    InitContext() = InitContext(InitFromPrior())
end

# `InitContext` has no child context.
NodeTrait(::InitContext) = IsLeaf()
156+
# Handle `vn ~ dist` under an `InitContext`: obtain a fresh value from the
# context's strategy, store it in `vi` (overwriting any existing value, and
# keeping that variable's linked status), and return the untransformed value
# together with the updated `vi`.
function tilde_assume(
    ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo
)
    in_varinfo = haskey(vi, vn)
    # `init()` always returns values in original space, i.e. possibly
    # constrained
    x = init(ctx.rng, vn, dist, ctx.strategy)
    # Determine whether to insert a transformed value into the VarInfo.
    # If the VarInfo already had a value for this variable, we will
    # keep the same linked status as in the original VarInfo. If not, we
    # check the rest of the VarInfo to see if other variables are linked.
    # istrans(vi) returns true if vi is nonempty and all variables in vi
    # are linked.
    insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi)
    f = if insert_transformed_value
        link_transform(dist)
    else
        identity
    end
    y, logjac = with_logabsdet_jacobian(f, x)
    # Add the new value to the VarInfo. `push!!` errors if the value already
    # exists, hence the need for setindex!!.
    if in_varinfo
        vi = setindex!!(vi, y, vn)
    else
        vi = push!!(vi, vn, y, dist)
    end
    # Neither of these set the `trans` flag so we have to do it manually if
    # necessary. As with every `!!` function, the returned object has to be
    # rebound: the update is not guaranteed to happen in place.
    if insert_transformed_value
        vi = settrans!!(vi, true, vn)
    end
    # `accumulate_assume!!` wants untransformed values as the second argument.
    vi = accumulate_assume!!(vi, x, logjac, vn, dist)
    # We always return the untransformed value here, as that will determine
    # what the lhs of the tilde-statement is set to.
    return x, vi
end
193+
# Observations are forwarded unchanged to the `DefaultContext` method.
function tilde_observe!!(::InitContext, right, left, vn, vi)
    default_ctx = DefaultContext()
    return tilde_observe!!(default_ctx, right, left, vn, vi)
end

src/model.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,41 @@ julia> # Now `a.x` will be sampled.
799799
"""
800800
fixed(model::Model) = fixed(model.context)
801801

"""
    prefix(model::Model, x::VarName)
    prefix(model::Model, x::Val{sym})
    prefix(model::Model, x::Any)

Return `model` but with all random variables prefixed by `x`. The prefix `x` may be:
- a `VarName` (e.g. `@varname(a)`),
- a `Val{sym}` (e.g. `Val(:a)`), or
- a value of any other type, in which case `x` is converted to a Symbol and then to a
  `VarName`. Note that this will introduce runtime overheads so is not recommended
  unless absolutely necessary.

# Examples

```jldoctest
julia> using DynamicPPL: prefix

julia> @model demo() = x ~ Dirac(1)
demo (generic function with 2 methods)

julia> rand(prefix(demo(), @varname(my_prefix)))
(var"my_prefix.x" = 1,)

julia> rand(prefix(demo(), Val(:my_prefix)))
(var"my_prefix.x" = 1,)
```
"""
function prefix(model::Model, x::VarName)
    prefixed_context = PrefixContext(x, model.context)
    return contextualize(model, prefixed_context)
end
function prefix(model::Model, x::Val{sym}) where {sym}
    prefixed_context = PrefixContext(VarName{sym}(), model.context)
    return contextualize(model, prefixed_context)
end
function prefix(model::Model, x)
    # `Symbol(x)` is only known at runtime, hence the overhead mentioned above.
    prefixed_context = PrefixContext(VarName{Symbol(x)}(), model.context)
    return contextualize(model, prefixed_context)
end
836+
802837
"""
803838
(model::Model)([rng, varinfo])
804839
@@ -854,6 +889,41 @@ function evaluate_and_sample!!(
854889
return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler)
855890
end
856891

"""
    init!!(
        [rng::Random.AbstractRNG,]
        model::Model,
        varinfo::AbstractVarInfo,
        [init_strategy::AbstractInitStrategy=InitFromPrior()]
    )

Evaluate the `model`, replacing the values of the model's random variables in
`varinfo` with new values obtained from the given initialisation strategy.
Variables that are not yet present in `varinfo` are added using that same
strategy.

`init_strategy` defaults to `InitFromPrior()`, and `rng` to
`Random.default_rng()`.

Returns a tuple of the model's return value, plus the updated `varinfo` object.
"""
function init!!(
    rng::Random.AbstractRNG,
    model::Model,
    varinfo::AbstractVarInfo,
    init_strategy::AbstractInitStrategy=InitFromPrior(),
)
    # Swap only the leaf of the model's context stack for an `InitContext`;
    # any parent contexts are kept.
    init_context = InitContext(rng, init_strategy)
    init_model = contextualize(model, setleafcontext(model.context, init_context))
    return evaluate!!(init_model, varinfo)
end
function init!!(
    model::Model,
    varinfo::AbstractVarInfo,
    init_strategy::AbstractInitStrategy=InitFromPrior(),
)
    rng = Random.default_rng()
    return init!!(rng, model, varinfo, init_strategy)
end
926+
857927
"""
858928
evaluate!!(model::Model, varinfo)
859929

src/varnamedvector.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,11 @@ function update_internal!(
766766
return nothing
767767
end
768768

# Store `val` under `vn`: the value is vectorised with `tovec`, and the
# transform obtained from `from_vec_transform(dist)` is passed along with it.
function BangBang.push!(vnv::VarNamedVector, vn, val, dist)
    transform = from_vec_transform(dist)
    flat_val = tovec(val)
    return setindex_internal!(vnv, flat_val, vn, transform)
end
773+
769774
# BangBang versions of the above functions.
770775
# The only difference is that update_internal!! and insert_internal!! check whether the
771776
# container types of the VarNamedVector vector need to be expanded to accommodate the new

0 commit comments

Comments
 (0)