67
67
68
68
Return a default varinfo object for the given `model` and `sampler`.
69
69
70
+ The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo').
71
+
70
72
# Arguments
71
73
- `rng::Random.AbstractRNG`: Random number generator.
72
74
- `model::Model`: Model for which we want to create a varinfo object.
@@ -75,9 +77,10 @@ Return a default varinfo object for the given `model` and `sampler`.
75
77
# Returns
76
78
- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
77
79
"""
78
- function default_varinfo (rng:: Random.AbstractRNG , model:: Model , sampler:: AbstractSampler )
79
- init_sampler = initialsampler (sampler)
80
- return typed_varinfo (rng, model, init_sampler)
80
+ function default_varinfo (:: Random.AbstractRNG , :: Model , :: AbstractSampler )
81
+ # Note that variable values are unconditionally initialized later, so no
82
+ # point putting them in now.
83
+ return typed_varinfo (VarInfo ())
81
84
end
82
85
83
86
function AbstractMCMC. sample (
@@ -95,24 +98,32 @@ function AbstractMCMC.sample(
95
98
)
96
99
end
97
100
98
- # initial step: general interface for resuming and
101
+ """
102
+ init_strategy(sampler)
103
+
104
+ Define the initialisation strategy used for generating initial values when
105
+ sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden.
106
+ """
107
+ init_strategy (:: Sampler ) = PriorInit ()
108
+
99
109
function AbstractMCMC. step (
100
- rng:: Random.AbstractRNG , model:: Model , spl:: Sampler ; initial_params= nothing , kwargs...
110
+ rng:: Random.AbstractRNG ,
111
+ model:: Model ,
112
+ spl:: Sampler ;
113
+ initial_params:: AbstractInitStrategy = init_strategy (spl),
114
+ kwargs... ,
101
115
)
102
- # Sample initial values.
116
+ # Generate the default varinfo (usually this just makes an empty VarInfo
117
+ # with NamedTuple of Metadata).
103
118
vi = default_varinfo (rng, model, spl)
104
119
105
- # Update the parameters if provided.
106
- if initial_params != = nothing
107
- vi = initialize_parameters!! (vi, initial_params, model)
108
-
109
- # Update joint log probability.
110
- # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
111
- # and https://github.com/TuringLang/Turing.jl/issues/1563
112
- # to avoid that existing variables are resampled
113
- vi = last (evaluate!! (model, vi))
114
- end
120
+ # Fill it with initial parameters. Note that, if `ParamsInit` is used, the
121
+ # parameters provided must be in unlinked space (when inserted into the
122
+ # varinfo, they will be adjusted to match the linking status of the
123
+ # varinfo).
124
+ _, vi = init!! (rng, model, vi, initial_params)
115
125
126
+ # Call the actual function that does the first step.
116
127
return initialstep (rng, model, spl, vi; initial_params, kwargs... )
117
128
end
118
129
@@ -130,110 +141,15 @@ loadstate(data) = data
130
141
131
142
Default type of the chain of posterior samples from `sampler`.
132
143
"""
133
- default_chain_type (sampler:: Sampler ) = Any
134
-
135
- """
136
- initialsampler(sampler::Sampler)
137
-
138
- Return the sampler that is used for generating the initial parameters when sampling with
139
- `sampler`.
140
-
141
- By default, it returns an instance of [`SampleFromPrior`](@ref).
142
- """
143
- initialsampler (spl:: Sampler ) = SampleFromPrior ()
144
+ default_chain_type (:: Sampler ) = Any
144
145
145
146
"""
146
- set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector)
147
- set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple)
148
-
149
- Take the values inside `initial_params`, replace the corresponding values in
150
- the given VarInfo object, and return a new VarInfo object with the updated values.
151
-
152
- This differs from `DynamicPPL.unflatten` in two ways:
147
+ init_strategy(sampler)
153
148
154
- 1. It works with `NamedTuple` arguments.
155
- 2. For the `AbstractVector` method, if any of the elements are missing, it will not
156
- overwrite the original value in the VarInfo (it will just use the original
157
- value instead).
149
+ Define the initialisation strategy used for generating initial values when
150
+ sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden.
158
151
"""
159
- function set_initial_values (varinfo:: AbstractVarInfo , initial_params:: AbstractVector )
160
- throw (
161
- ArgumentError (
162
- " `initial_params` must be a vector of type `Union{Real,Missing}`. " *
163
- " If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first." ,
164
- ),
165
- )
166
- end
167
-
168
- function set_initial_values (
169
- varinfo:: AbstractVarInfo , initial_params:: AbstractVector{<:Union{Real,Missing}}
170
- )
171
- flattened_param_vals = varinfo[:]
172
- length (flattened_param_vals) == length (initial_params) || throw (
173
- DimensionMismatch (
174
- " Provided initial value size ($(length (initial_params)) ) doesn't match " *
175
- " the model size ($(length (flattened_param_vals)) )." ,
176
- ),
177
- )
178
-
179
- # Update values that are provided.
180
- for i in eachindex (initial_params)
181
- x = initial_params[i]
182
- if x != = missing
183
- flattened_param_vals[i] = x
184
- end
185
- end
186
-
187
- # Update in `varinfo`.
188
- new_varinfo = unflatten (varinfo, flattened_param_vals)
189
- return new_varinfo
190
- end
191
-
192
- function set_initial_values (varinfo:: AbstractVarInfo , initial_params:: NamedTuple )
193
- varinfo = deepcopy (varinfo)
194
- vars_in_varinfo = keys (varinfo)
195
- for v in keys (initial_params)
196
- vn = VarName {v} ()
197
- if ! (vn in vars_in_varinfo)
198
- for vv in vars_in_varinfo
199
- if subsumes (vn, vv)
200
- throw (
201
- ArgumentError (
202
- " The current model contains sub-variables of $v , such as ($vv ). " *
203
- " Using NamedTuple for initial_params is not supported in such a case. " *
204
- " Please use AbstractVector for initial_params instead of NamedTuple." ,
205
- ),
206
- )
207
- end
208
- end
209
- throw (ArgumentError (" Variable $v not found in the model." ))
210
- end
211
- end
212
- initial_params = NamedTuple (k => v for (k, v) in pairs (initial_params) if v != = missing )
213
- return update_values!! (
214
- varinfo, initial_params, map (k -> VarName {k} (), keys (initial_params))
215
- )
216
- end
217
-
218
- function initialize_parameters!! (vi:: AbstractVarInfo , initial_params, model:: Model )
219
- @debug " Using passed-in initial variable values" initial_params
220
-
221
- # `link` the varinfo if needed.
222
- linked = islinked (vi)
223
- if linked
224
- vi = invlink!! (vi, model)
225
- end
226
-
227
- # Set the values in `vi`.
228
- vi = set_initial_values (vi, initial_params)
229
-
230
- # `invlink` if needed.
231
- if linked
232
- vi = link!! (vi, model)
233
- end
234
-
235
- return vi
236
- end
152
+ init_strategy (:: Sampler ) = PriorInit ()
237
153
238
154
"""
239
155
initialstep(rng, model, sampler, varinfo; kwargs...)
0 commit comments