|
16 | 16 | import sys
|
17 | 17 | sys.setrecursionlimit(10000)
|
18 | 18 |
|
19 |
| -__all__ = ['sample', 'iter_sample', 'sample_ppc'] |
| 19 | +__all__ = ['sample', 'iter_sample', 'sample_ppc', 'sample_init'] |
20 | 20 |
|
21 | 21 |
|
22 | 22 | def assign_step_methods(model, step=None, methods=(NUTS, HamiltonianMC, Metropolis,
|
@@ -373,3 +373,80 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None, random_see
|
373 | 373 | size=size))
|
374 | 374 |
|
375 | 375 | return {k: np.asarray(v) for k, v in ppc.items()}
|
| 376 | + |
| 377 | + |
| 378 | +def sample_init(draws=2000, init='advi', n_init=500000, sampler='nuts', |
| 379 | + model=None, **kwargs): |
| 380 | + """Initialize and sample from posterior of a continuous model. |
| 381 | +
|
| 382 | + This is a convenience function. NUTS convergence and sampling speed is extremely |
| 383 | + dependent on the choice of mass/scaling matrix. In our experience, using ADVI |
| 384 | + to estimate a diagonal covariance matrix and using this as the scaling matrix |
| 385 | + produces robust results over a wide class of continuous models. |
| 386 | +
|
| 387 | + Parameteres |
| 388 | + ----------- |
| 389 | + init : str {'advi', 'map', 'metropolis', 'nuts'} |
| 390 | + Initialization method to use. |
| 391 | + n_init : int |
| 392 | + Number of iterations of initializer |
| 393 | + If 'advi', number of iterations, if 'metropolis', number of draws. |
| 394 | + sampler : str {'nuts', 'hmc', advi'} |
| 395 | + Sampler to use. Will be initialized using init algorithm. |
| 396 | + draws : int |
| 397 | + Number of posterior samples to draw. |
| 398 | + njobs : int |
| 399 | + Number of parallel jobs to start. If None, set to number of cpus |
| 400 | + in the system - 2. |
| 401 | + **kwargs : additional keyword argumemts |
| 402 | + Additional keyword argumemts are forwared to pymc3.sample() |
| 403 | +
|
| 404 | + Returns |
| 405 | + ------- |
| 406 | + MultiTrace object with access to sampling values |
| 407 | + """ |
| 408 | + |
| 409 | + model = pm.modelcontext(model) |
| 410 | + pm._log.info('Initializing using {}...'.format(init)) |
| 411 | + |
| 412 | + if init == 'advi': |
| 413 | + v_params = pm.variational.advi(n=n_init) |
| 414 | + start = v_params.means |
| 415 | + cov = np.diagflat(np.power(model.dict_to_array(v_params.stds), 2)) |
| 416 | + |
| 417 | + elif init == 'map': |
| 418 | + start = pm.find_MAP() |
| 419 | + cov = pm.find_hessian(point=start) |
| 420 | + |
| 421 | + elif init == 'metropolis': |
| 422 | + init_trace = pm.sample(step=pm.Metropolis(), draws=n_init) |
| 423 | + cov = pm.trace_cov(init_trace) |
| 424 | + |
| 425 | + start = {varname: np.mean(init_trace[varname]) for varname in init_trace.varnames} |
| 426 | + elif init == 'nuts': |
| 427 | + init_trace = pm.sample(step=pm.NUTS(), draws=n_init) |
| 428 | + cov = pm.trace_cov(init_trace) |
| 429 | + |
| 430 | + start = {varname: np.mean(init_trace[varname]) for varname in init_trace.varnames} |
| 431 | + else: |
| 432 | + raise NotImplemented('Initializer {} is not supported.'.format(init)) |
| 433 | + |
| 434 | + pm._log.info('Sampling using {}...'.format(sampler)) |
| 435 | + if sampler == 'nuts': |
| 436 | + step = pm.NUTS(scaling=cov, is_cov=True) |
| 437 | + elif sampler == 'hmc': |
| 438 | + step = pm.HamiltonianMC(scaling=cov, is_cov=True) |
| 439 | + elif sampler == 'metropolis': |
| 440 | + step = pm.Metropolis(scaling=cov, |
| 441 | + proposal=pm.step_methods.metropolis.MultivariateNormalProposal) |
| 442 | + elif sampler != 'advi': |
| 443 | + raise NotImplemented('Sampler {} is not supported.'.format(init)) |
| 444 | + |
| 445 | + if sampler == 'advi': |
| 446 | + if init != 'advi': |
| 447 | + raise ValueError("To sample via ADVI, you have to set init='advi'.") |
| 448 | + trace = pm.variational.sample_vp(v_params, draws=draws) |
| 449 | + else: |
| 450 | + trace = pm.sample(step=step, start=start, draws=draws, **kwargs) |
| 451 | + |
| 452 | + return trace |
0 commit comments