@@ -41,7 +41,7 @@ Generic sampler type for inference algorithms of type `T` in DynamicPPL.
41
41
provided that supports resuming sampling from a previous state and setting initial
42
42
parameter values. It requires to overload [`loadstate`](@ref) and [`initialstep`](@ref)
43
43
for loading previous states and actually performing the initial sampling step,
44
- respectively. Additionally, sometimes one might want to implement [`initialsampler `](@ref)
44
+ respectively. Additionally, sometimes one might want to implement [`init_strategy `](@ref)
45
45
that specifies how the initial parameter values are sampled if they are not provided.
46
46
By default, values are sampled from the prior.
47
47
"""
68
68
69
69
Return a default varinfo object for the given `model` and `sampler`.
70
70
71
+ The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo').
72
+
71
73
# Arguments
72
74
- `rng::Random.AbstractRNG`: Random number generator.
73
75
- `model::Model`: Model for which we want to create a varinfo object.
@@ -76,9 +78,10 @@ Return a default varinfo object for the given `model` and `sampler`.
76
78
# Returns
77
79
- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
78
80
"""
79
- function default_varinfo (rng:: Random.AbstractRNG , model:: Model , sampler:: AbstractSampler )
80
- init_sampler = initialsampler (sampler)
81
- return typed_varinfo (rng, model, init_sampler)
81
+ function default_varinfo (:: Random.AbstractRNG , :: Model , :: AbstractSampler )
82
+ # Note that variable values are unconditionally initialized later, so no
83
+ # point putting them in now.
84
+ return typed_varinfo (VarInfo ())
82
85
end
83
86
84
87
function AbstractMCMC. sample (
@@ -96,24 +99,32 @@ function AbstractMCMC.sample(
96
99
)
97
100
end
98
101
99
- # initial step: general interface for resuming and
102
+ """
103
+ init_strategy(sampler)
104
+
105
+ Define the initialisation strategy used for generating initial values when
106
+ sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden.
107
+ """
108
+ init_strategy (:: Sampler ) = PriorInit ()
109
+
100
110
function AbstractMCMC. step (
101
- rng:: Random.AbstractRNG , model:: Model , spl:: Sampler ; initial_params= nothing , kwargs...
111
+ rng:: Random.AbstractRNG ,
112
+ model:: Model ,
113
+ spl:: Sampler ;
114
+ initial_params:: AbstractInitStrategy = init_strategy (spl),
115
+ kwargs... ,
102
116
)
103
- # Sample initial values.
117
+ # Generate the default varinfo (usually this just makes an empty VarInfo
118
+ # with NamedTuple of Metadata).
104
119
vi = default_varinfo (rng, model, spl)
105
120
106
- # Update the parameters if provided.
107
- if initial_params != = nothing
108
- vi = initialize_parameters!! (vi, initial_params, model)
109
-
110
- # Update joint log probability.
111
- # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
112
- # and https://github.com/TuringLang/Turing.jl/issues/1563
113
- # to avoid that existing variables are resampled
114
- vi = last (evaluate!! (model, vi))
115
- end
121
+ # Fill it with initial parameters. Note that, if `ParamsInit` is used, the
122
+ # parameters provided must be in unlinked space (when inserted into the
123
+ # varinfo, they will be adjusted to match the linking status of the
124
+ # varinfo).
125
+ _, vi = init!! (rng, model, vi, initial_params)
116
126
127
+ # Call the actual function that does the first step.
117
128
return initialstep (rng, model, spl, vi; initial_params, kwargs... )
118
129
end
119
130
@@ -131,110 +142,7 @@ loadstate(data) = data
131
142
132
143
Default type of the chain of posterior samples from `sampler`.
133
144
"""
134
- default_chain_type (sampler:: Sampler ) = Any
135
-
136
- """
137
- initialsampler(sampler::Sampler)
138
-
139
- Return the sampler that is used for generating the initial parameters when sampling with
140
- `sampler`.
141
-
142
- By default, it returns an instance of [`SampleFromPrior`](@ref).
143
- """
144
- initialsampler (spl:: Sampler ) = SampleFromPrior ()
145
-
146
- """
147
- set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
148
- set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
149
-
150
- Take the values inside `initial_params`, replace the corresponding values in
151
- the given VarInfo object, and return a new VarInfo object with the updated values.
152
-
153
- This differs from `DynamicPPL.unflatten` in two ways:
154
-
155
- 1. It works with `NamedTuple` arguments.
156
- 2. For the `AbstractVector` method, if any of the elements are missing, it will not
157
- overwrite the original value in the VarInfo (it will just use the original
158
- value instead).
159
- """
160
- function set_initial_values (varinfo:: AbstractVarInfo , initial_params:: AbstractVector )
161
- throw (
162
- ArgumentError (
163
- " `initial_params` must be a vector of type `Union{Real,Missing}`. " *
164
- " If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first." ,
165
- ),
166
- )
167
- end
168
-
169
- function set_initial_values (
170
- varinfo:: AbstractVarInfo , initial_params:: AbstractVector{<:Union{Real,Missing}}
171
- )
172
- flattened_param_vals = varinfo[:]
173
- length (flattened_param_vals) == length (initial_params) || throw (
174
- DimensionMismatch (
175
- " Provided initial value size ($(length (initial_params)) ) doesn't match " *
176
- " the model size ($(length (flattened_param_vals)) )." ,
177
- ),
178
- )
179
-
180
- # Update values that are provided.
181
- for i in eachindex (initial_params)
182
- x = initial_params[i]
183
- if x != = missing
184
- flattened_param_vals[i] = x
185
- end
186
- end
187
-
188
- # Update in `varinfo`.
189
- new_varinfo = unflatten (varinfo, flattened_param_vals)
190
- return new_varinfo
191
- end
192
-
193
- function set_initial_values (varinfo:: AbstractVarInfo , initial_params:: NamedTuple )
194
- varinfo = deepcopy (varinfo)
195
- vars_in_varinfo = keys (varinfo)
196
- for v in keys (initial_params)
197
- vn = VarName {v} ()
198
- if ! (vn in vars_in_varinfo)
199
- for vv in vars_in_varinfo
200
- if subsumes (vn, vv)
201
- throw (
202
- ArgumentError (
203
- " The current model contains sub-variables of $v , such as ($vv ). " *
204
- " Using NamedTuple for initial_params is not supported in such a case. " *
205
- " Please use AbstractVector for initial_params instead of NamedTuple." ,
206
- ),
207
- )
208
- end
209
- end
210
- throw (ArgumentError (" Variable $v not found in the model." ))
211
- end
212
- end
213
- initial_params = NamedTuple (k => v for (k, v) in pairs (initial_params) if v != = missing )
214
- return update_values!! (
215
- varinfo, initial_params, map (k -> VarName {k} (), keys (initial_params))
216
- )
217
- end
218
-
219
- function initialize_parameters!! (vi:: AbstractVarInfo , initial_params, model:: Model )
220
- @debug " Using passed-in initial variable values" initial_params
221
-
222
- # `link` the varinfo if needed.
223
- linked = islinked (vi)
224
- if linked
225
- vi = invlink!! (vi, model)
226
- end
227
-
228
- # Set the values in `vi`.
229
- vi = set_initial_values (vi, initial_params)
230
-
231
- # `invlink` if needed.
232
- if linked
233
- vi = link!! (vi, model)
234
- end
235
-
236
- return vi
237
- end
145
+ default_chain_type (:: Sampler ) = Any
238
146
239
147
"""
240
148
initialstep(rng, model, sampler, varinfo; kwargs...)
0 commit comments