|
1 |
| -from collections import defaultdict |
| 1 | +from collections import defaultdict, Sequence |
2 | 2 |
|
3 | 3 | from joblib import Parallel, delayed
|
4 | 4 | from numpy.random import randint, seed
|
@@ -144,7 +144,7 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
|
144 | 144 | n_init : int
|
145 | 145 | Number of iterations of initializer
|
146 | 146 | If 'ADVI', number of iterations, if 'nuts', number of draws.
|
147 |
| - start : dict |
| 147 | + start : dict, or array of dict |
148 | 148 | Starting point in parameter space (or partial point)
|
149 | 149 | Defaults to trace.point(-1)) if there is a trace provided and
|
150 | 150 | model.test_point if not (defaults to empty dict).
|
@@ -227,6 +227,9 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
|
227 | 227 | """
|
228 | 228 | model = modelcontext(model)
|
229 | 229 |
|
| 230 | + if start is not None: |
| 231 | + _check_start_shape(model, start) |
| 232 | + |
230 | 233 | draws += tune
|
231 | 234 |
|
232 | 235 | if nuts_kwargs is not None:
|
@@ -280,6 +283,38 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
|
280 | 283 | return sample_func(**sample_args)[discard:]
|
281 | 284 |
|
282 | 285 |
|
| 286 | +def _check_start_shape(model, start): |
| 287 | + e = '' |
| 288 | + if isinstance(start, (Sequence, np.ndarray)): |
| 289 | + # to deal with iterable start argument |
| 290 | + for start_iter in start: |
| 291 | + _check_start_shape(model, start_iter) |
| 292 | + return |
| 293 | + elif not isinstance(start, dict): |
| 294 | + raise TypeError("start argument must be a dict " |
| 295 | + "or an array-like of dicts") |
| 296 | + for var in model.vars: |
| 297 | + if var.name in start.keys(): |
| 298 | + var_shape = var.shape.tag.test_value |
| 299 | + start_var_shape = np.shape(start[var.name]) |
| 300 | + if start_var_shape: |
| 301 | + if not np.array_equal(var_shape, start_var_shape): |
| 302 | + e += "\nExpected shape {} for var '{}', got: {}".format( |
| 303 | + tuple(var_shape), var.name, start_var_shape |
| 304 | + ) |
| 305 | + # if start var has no shape |
| 306 | + else: |
| 307 | + # if model var has a specified shape |
| 308 | + if var_shape: |
| 309 | + e += "\nExpected shape {} for var " \ |
| 310 | + "'{}', got scalar {}".format( |
| 311 | + tuple(var_shape), var.name, start[var.name] |
| 312 | + ) |
| 313 | + |
| 314 | + if e != '': |
| 315 | + raise ValueError("Bad shape for start argument:{}".format(e)) |
| 316 | + |
| 317 | + |
283 | 318 | def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
|
284 | 319 | progressbar=True, model=None, random_seed=-1, live_plot=False,
|
285 | 320 | live_plot_kwargs=None, **kwargs):
|
|
0 commit comments