Skip to content

Commit 3d26226

Browse files
committed
Use init!! for initialisation
1 parent 23b569d commit 3d26226

File tree

1 file changed

+32
-116
lines changed

1 file changed

+32
-116
lines changed

src/sampler.jl

Lines changed: 32 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ end
6767
6868
Return a default varinfo object for the given `model` and `sampler`.
6969
70+
The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo').
71+
7072
# Arguments
7173
- `rng::Random.AbstractRNG`: Random number generator.
7274
- `model::Model`: Model for which we want to create a varinfo object.
@@ -75,9 +77,10 @@ Return a default varinfo object for the given `model` and `sampler`.
7577
# Returns
7678
- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
7779
"""
78-
function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler)
79-
init_sampler = initialsampler(sampler)
80-
return typed_varinfo(rng, model, init_sampler)
80+
function default_varinfo(::Random.AbstractRNG, ::Model, ::AbstractSampler)
81+
# Note that variable values are unconditionally initialized later, so no
82+
# point putting them in now.
83+
return typed_varinfo(VarInfo())
8184
end
8285

8386
function AbstractMCMC.sample(
@@ -95,24 +98,32 @@ function AbstractMCMC.sample(
9598
)
9699
end
97100

98-
# initial step: general interface for resuming and
101+
"""
102+
init_strategy(sampler)
103+
104+
Define the initialisation strategy used for generating initial values when
105+
sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden.
106+
"""
107+
init_strategy(::Sampler) = PriorInit()
108+
99109
function AbstractMCMC.step(
100-
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
110+
rng::Random.AbstractRNG,
111+
model::Model,
112+
spl::Sampler;
113+
initial_params::AbstractInitStrategy=init_strategy(spl),
114+
kwargs...,
101115
)
102-
# Sample initial values.
116+
# Generate the default varinfo (usually this just makes an empty VarInfo
117+
# with NamedTuple of Metadata).
103118
vi = default_varinfo(rng, model, spl)
104119

105-
# Update the parameters if provided.
106-
if initial_params !== nothing
107-
vi = initialize_parameters!!(vi, initial_params, model)
108-
109-
# Update joint log probability.
110-
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
111-
# and https://github.com/TuringLang/Turing.jl/issues/1563
112-
# to avoid that existing variables are resampled
113-
vi = last(evaluate!!(model, vi))
114-
end
120+
# Fill it with initial parameters. Note that, if `ParamsInit` is used, the
121+
# parameters provided must be in unlinked space (when inserted into the
122+
# varinfo, they will be adjusted to match the linking status of the
123+
# varinfo).
124+
_, vi = init!!(rng, model, vi, initial_params)
115125

126+
# Call the actual function that does the first step.
116127
return initialstep(rng, model, spl, vi; initial_params, kwargs...)
117128
end
118129

@@ -130,110 +141,15 @@ loadstate(data) = data
130141
131142
Default type of the chain of posterior samples from `sampler`.
132143
"""
133-
default_chain_type(sampler::Sampler) = Any
134-
135-
"""
136-
initialsampler(sampler::Sampler)
137-
138-
Return the sampler that is used for generating the initial parameters when sampling with
139-
`sampler`.
140-
141-
By default, it returns an instance of [`SampleFromPrior`](@ref).
142-
"""
143-
initialsampler(spl::Sampler) = SampleFromPrior()
144+
default_chain_type(::Sampler) = Any
144145

145146
"""
146-
set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
147-
set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
148-
149-
Take the values inside `initial_params`, replace the corresponding values in
150-
the given VarInfo object, and return a new VarInfo object with the updated values.
151-
152-
This differs from `DynamicPPL.unflatten` in two ways:
147+
init_strategy(sampler)
153148
154-
1. It works with `NamedTuple` arguments.
155-
2. For the `AbstractVector` method, if any of the elements are missing, it will not
156-
overwrite the original value in the VarInfo (it will just use the original
157-
value instead).
149+
Define the initialisation strategy used for generating initial values when
150+
sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden.
158151
"""
159-
function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
160-
throw(
161-
ArgumentError(
162-
"`initial_params` must be a vector of type `Union{Real,Missing}`. " *
163-
"If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.",
164-
),
165-
)
166-
end
167-
168-
function set_initial_values(
169-
varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}
170-
)
171-
flattened_param_vals = varinfo[:]
172-
length(flattened_param_vals) == length(initial_params) || throw(
173-
DimensionMismatch(
174-
"Provided initial value size ($(length(initial_params))) doesn't match " *
175-
"the model size ($(length(flattened_param_vals))).",
176-
),
177-
)
178-
179-
# Update values that are provided.
180-
for i in eachindex(initial_params)
181-
x = initial_params[i]
182-
if x !== missing
183-
flattened_param_vals[i] = x
184-
end
185-
end
186-
187-
# Update in `varinfo`.
188-
new_varinfo = unflatten(varinfo, flattened_param_vals)
189-
return new_varinfo
190-
end
191-
192-
function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
193-
varinfo = deepcopy(varinfo)
194-
vars_in_varinfo = keys(varinfo)
195-
for v in keys(initial_params)
196-
vn = VarName{v}()
197-
if !(vn in vars_in_varinfo)
198-
for vv in vars_in_varinfo
199-
if subsumes(vn, vv)
200-
throw(
201-
ArgumentError(
202-
"The current model contains sub-variables of $v, such as ($vv). " *
203-
"Using NamedTuple for initial_params is not supported in such a case. " *
204-
"Please use AbstractVector for initial_params instead of NamedTuple.",
205-
),
206-
)
207-
end
208-
end
209-
throw(ArgumentError("Variable $v not found in the model."))
210-
end
211-
end
212-
initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing)
213-
return update_values!!(
214-
varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params))
215-
)
216-
end
217-
218-
function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model)
219-
@debug "Using passed-in initial variable values" initial_params
220-
221-
# `link` the varinfo if needed.
222-
linked = islinked(vi)
223-
if linked
224-
vi = invlink!!(vi, model)
225-
end
226-
227-
# Set the values in `vi`.
228-
vi = set_initial_values(vi, initial_params)
229-
230-
# `invlink` if needed.
231-
if linked
232-
vi = link!!(vi, model)
233-
end
234-
235-
return vi
236-
end
152+
init_strategy(::Sampler) = PriorInit()
237153

238154
"""
239155
initialstep(rng, model, sampler, varinfo; kwargs...)

0 commit comments

Comments
 (0)