|
| 1 | +""" |
| 2 | + AbstractInitStrategy |
| 3 | +
|
| 4 | +Abstract type representing the possible ways of initialising new values for |
| 5 | +the random variables in a model (e.g., when creating a new VarInfo). |
| 6 | +
|
| 7 | +Any subtype of `AbstractInitStrategy` must implement the |
| 8 | +[`DynamicPPL.init`](@ref) method. |
| 9 | +""" |
| 10 | +abstract type AbstractInitStrategy end |
| 11 | + |
| 12 | +""" |
| 13 | + init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy) |
| 14 | +
|
| 15 | +Generate a new value for a random variable with the given distribution. |
| 16 | +
|
| 17 | +!!! warning "Return values must be unlinked" |
| 18 | + The values returned by `init` must always be in the untransformed space, i.e., |
| 19 | + they must be within the support of the original distribution. That means that, |
| 20 | + for example, `init(rng, dist, u::InitFromUniform)` will in general return values that |
| 21 | + are outside the range [u.lower, u.upper]. |
| 22 | +""" |
| 23 | +function init end |
| 24 | + |
| 25 | +""" |
| 26 | + InitFromPrior() |
| 27 | +
|
| 28 | +Obtain new values by sampling from the prior distribution. |
| 29 | +""" |
| 30 | +struct InitFromPrior <: AbstractInitStrategy end |
| 31 | +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) |
| 32 | + return rand(rng, dist) |
| 33 | +end |
| 34 | + |
| 35 | +""" |
| 36 | + InitFromUniform() |
| 37 | + InitFromUniform(lower, upper) |
| 38 | +
|
| 39 | +Obtain new values by first transforming the distribution of the random variable |
| 40 | +to unconstrained space, then sampling a value uniformly between `lower` and |
| 41 | +`upper`, and transforming that value back to the original space. |
| 42 | +
|
| 43 | +If `lower` and `upper` are unspecified, they default to `(-2, 2)`, which mimics |
| 44 | +Stan's default initialisation strategy. |
| 45 | +
|
| 46 | +Requires that `lower <= upper`. |
| 47 | +
|
| 48 | +# References |
| 49 | +
|
| 50 | +[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization) |
| 51 | +""" |
| 52 | +struct InitFromUniform{T<:AbstractFloat} <: AbstractInitStrategy |
| 53 | + lower::T |
| 54 | + upper::T |
| 55 | + function InitFromUniform(lower::T, upper::T) where {T<:AbstractFloat} |
| 56 | + lower > upper && |
| 57 | + throw(ArgumentError("`lower` must be less than or equal to `upper`")) |
| 58 | + return new{T}(lower, upper) |
| 59 | + end |
| 60 | + InitFromUniform() = InitFromUniform(-2.0, 2.0) |
| 61 | +end |
| 62 | +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFromUniform) |
| 63 | + b = Bijectors.bijector(dist) |
| 64 | + sz = Bijectors.output_size(b, size(dist)) |
| 65 | + y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...)) |
| 66 | + b_inv = Bijectors.inverse(b) |
| 67 | + x = b_inv(y) |
| 68 | + # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398 |
| 69 | + if x isa Array{<:Any,0} |
| 70 | + x = x[] |
| 71 | + end |
| 72 | + return x |
| 73 | +end |
| 74 | + |
| 75 | +""" |
| 76 | + InitFromParams( |
| 77 | + params::Union{AbstractDict{<:VarName},NamedTuple}, |
| 78 | + fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() |
| 79 | + ) |
| 80 | +
|
| 81 | +Obtain new values by extracting them from the given dictionary or NamedTuple. |
| 82 | +
|
| 83 | +The parameter `fallback` specifies how new values are to be obtained if they |
| 84 | +cannot be found in `params`, or they are specified as `missing`. `fallback` |
| 85 | +can either be an initialisation strategy itself, in which case it will be |
| 86 | +used to obtain new values, or it can be `nothing`, in which case an error |
| 87 | +will be thrown. The default for `fallback` is `InitFromPrior()`. |
| 88 | +
|
| 89 | +!!! note |
| 90 | + The values in `params` must be provided in the space of the untransformed |
| 91 | + distribution. |
| 92 | +""" |
| 93 | +struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy |
| 94 | + params::P |
| 95 | + fallback::S |
| 96 | + function InitFromParams( |
| 97 | + params::AbstractDict{<:VarName}, fallback::Union{AbstractInitStrategy,Nothing} |
| 98 | + ) |
| 99 | + return new{typeof(params),typeof(fallback)}(params, fallback) |
| 100 | + end |
| 101 | + function InitFromParams(params::AbstractDict{<:VarName}) |
| 102 | + return InitFromParams(params, InitFromPrior()) |
| 103 | + end |
| 104 | + function InitFromParams( |
| 105 | + params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() |
| 106 | + ) |
| 107 | + return InitFromParams(to_varname_dict(params), fallback) |
| 108 | + end |
| 109 | +end |
| 110 | +function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) |
| 111 | + # TODO(penelopeysm): It would be nice to do a check to make sure that all |
| 112 | + # of the parameters in `p.params` were actually used, and either warn or |
| 113 | + # error if they aren't. This is actually quite non-trivial though because |
| 114 | + # the structure of Dicts in particular can have arbitrary nesting. |
| 115 | + return if hasvalue(p.params, vn, dist) |
| 116 | + x = getvalue(p.params, vn, dist) |
| 117 | + if x === missing |
| 118 | + p.fallback === nothing && |
| 119 | + error("A `missing` value was provided for the variable `$(vn)`.") |
| 120 | + init(rng, vn, dist, p.fallback) |
| 121 | + else |
| 122 | + # TODO(penelopeysm): Since x is user-supplied, maybe we could also |
| 123 | + # check here that the type / size of x matches the dist? |
| 124 | + x |
| 125 | + end |
| 126 | + else |
| 127 | + p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") |
| 128 | + init(rng, vn, dist, p.fallback) |
| 129 | + end |
| 130 | +end |
| 131 | + |
| 132 | +""" |
| 133 | + InitContext( |
| 134 | + [rng::Random.AbstractRNG=Random.default_rng()], |
| 135 | + [strategy::AbstractInitStrategy=InitFromPrior()], |
| 136 | + ) |
| 137 | +
|
| 138 | +A leaf context that indicates that new values for random variables are |
| 139 | +currently being obtained through sampling. Used e.g. when initialising a fresh |
| 140 | +VarInfo. Note that, if `leafcontext(model.context) isa InitContext`, then |
| 141 | +`evaluate!!(model, varinfo)` will override all values in the VarInfo. |
| 142 | +""" |
| 143 | +struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext |
| 144 | + rng::R |
| 145 | + strategy::S |
| 146 | + function InitContext( |
| 147 | + rng::Random.AbstractRNG, strategy::AbstractInitStrategy=InitFromPrior() |
| 148 | + ) |
| 149 | + return new{typeof(rng),typeof(strategy)}(rng, strategy) |
| 150 | + end |
| 151 | + function InitContext(strategy::AbstractInitStrategy=InitFromPrior()) |
| 152 | + return InitContext(Random.default_rng(), strategy) |
| 153 | + end |
| 154 | +end |
| 155 | +NodeTrait(::InitContext) = IsLeaf() |
| 156 | + |
| 157 | +function tilde_assume( |
| 158 | + ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo |
| 159 | +) |
| 160 | + in_varinfo = haskey(vi, vn) |
| 161 | + # `init()` always returns values in original space, i.e. possibly |
| 162 | + # constrained |
| 163 | + x = init(ctx.rng, vn, dist, ctx.strategy) |
| 164 | + # Determine whether to insert a transformed value into the VarInfo. |
| 165 | + # If the VarInfo alrady had a value for this variable, we will |
| 166 | + # keep the same linked status as in the original VarInfo. If not, we |
| 167 | + # check the rest of the VarInfo to see if other variables are linked. |
| 168 | + # istrans(vi) returns true if vi is nonempty and all variables in vi |
| 169 | + # are linked. |
| 170 | + insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi) |
| 171 | + f = if insert_transformed_value |
| 172 | + link_transform(dist) |
| 173 | + else |
| 174 | + identity |
| 175 | + end |
| 176 | + y, logjac = with_logabsdet_jacobian(f, x) |
| 177 | + # Add the new value to the VarInfo. `push!!` errors if the value already |
| 178 | + # exists, hence the need for setindex!!. |
| 179 | + if in_varinfo |
| 180 | + vi = setindex!!(vi, y, vn) |
| 181 | + else |
| 182 | + vi = push!!(vi, vn, y, dist) |
| 183 | + end |
| 184 | + # Neither of these set the `trans` flag so we have to do it manually if |
| 185 | + # necessary. |
| 186 | + insert_transformed_value && settrans!!(vi, true, vn) |
| 187 | + # `accumulate_assume!!` wants untransformed values as the second argument. |
| 188 | + vi = accumulate_assume!!(vi, x, logjac, vn, dist) |
| 189 | + # We always return the untransformed value here, as that will determine |
| 190 | + # what the lhs of the tilde-statement is set to. |
| 191 | + return x, vi |
| 192 | +end |
| 193 | + |
| 194 | +function tilde_observe!!(::InitContext, right, left, vn, vi) |
| 195 | + return tilde_observe!!(DefaultContext(), right, left, vn, vi) |
| 196 | +end |
0 commit comments