Skip to content

Commit 77b67d0

Browse files
committed
ENH Add sample_init() function and update examples.
1 parent 1770be2 commit 77b67d0

File tree

5 files changed

+380
-186
lines changed

5 files changed

+380
-186
lines changed

docs/source/notebooks/GLM-hierarchical.ipynb

Lines changed: 110 additions & 39 deletions
Large diffs are not rendered by default.

docs/source/notebooks/NUTS_scaling_using_ADVI.ipynb

Lines changed: 126 additions & 58 deletions
Large diffs are not rendered by default.

docs/source/notebooks/stochastic_volatility.ipynb

Lines changed: 52 additions & 88 deletions
Large diffs are not rendered by default.

pymc3/sampling.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import sys
1717
sys.setrecursionlimit(10000)
1818

19-
__all__ = ['sample', 'iter_sample', 'sample_ppc']
19+
__all__ = ['sample', 'iter_sample', 'sample_ppc', 'sample_init']
2020

2121

2222
def assign_step_methods(model, step=None, methods=(NUTS, HamiltonianMC, Metropolis,
@@ -373,3 +373,80 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None, random_see
373373
size=size))
374374

375375
return {k: np.asarray(v) for k, v in ppc.items()}
376+
377+
378+
def sample_init(draws=2000, init='advi', n_init=500000, sampler='nuts',
379+
model=None, **kwargs):
380+
"""Initialize and sample from posterior of a continuous model.
381+
382+
This is a convenience function. NUTS convergence and sampling speed is extremely
383+
dependent on the choice of mass/scaling matrix. In our experience, using ADVI
384+
to estimate a diagonal covariance matrix and using this as the scaling matrix
385+
produces robust results over a wide class of continuous models.
386+
387+
Parameteres
388+
-----------
389+
init : str {'advi', 'map', 'metropolis', 'nuts'}
390+
Initialization method to use.
391+
n_init : int
392+
Number of iterations of initializer
393+
If 'advi', number of iterations, if 'metropolis', number of draws.
394+
sampler : str {'nuts', 'hmc', advi'}
395+
Sampler to use. Will be initialized using init algorithm.
396+
draws : int
397+
Number of posterior samples to draw.
398+
njobs : int
399+
Number of parallel jobs to start. If None, set to number of cpus
400+
in the system - 2.
401+
**kwargs : additional keyword argumemts
402+
Additional keyword argumemts are forwared to pymc3.sample()
403+
404+
Returns
405+
-------
406+
MultiTrace object with access to sampling values
407+
"""
408+
409+
model = pm.modelcontext(model)
410+
pm._log.info('Initializing using {}...'.format(init))
411+
412+
if init == 'advi':
413+
v_params = pm.variational.advi(n=n_init)
414+
start = v_params.means
415+
cov = np.diagflat(np.power(model.dict_to_array(v_params.stds), 2))
416+
417+
elif init == 'map':
418+
start = pm.find_MAP()
419+
cov = pm.find_hessian(point=start)
420+
421+
elif init == 'metropolis':
422+
init_trace = pm.sample(step=pm.Metropolis(), draws=n_init)
423+
cov = pm.trace_cov(init_trace)
424+
425+
start = {varname: np.mean(init_trace[varname]) for varname in init_trace.varnames}
426+
elif init == 'nuts':
427+
init_trace = pm.sample(step=pm.NUTS(), draws=n_init)
428+
cov = pm.trace_cov(init_trace)
429+
430+
start = {varname: np.mean(init_trace[varname]) for varname in init_trace.varnames}
431+
else:
432+
raise NotImplemented('Initializer {} is not supported.'.format(init))
433+
434+
pm._log.info('Sampling using {}...'.format(sampler))
435+
if sampler == 'nuts':
436+
step = pm.NUTS(scaling=cov, is_cov=True)
437+
elif sampler == 'hmc':
438+
step = pm.HamiltonianMC(scaling=cov, is_cov=True)
439+
elif sampler == 'metropolis':
440+
step = pm.Metropolis(scaling=cov,
441+
proposal=pm.step_methods.metropolis.MultivariateNormalProposal)
442+
elif sampler != 'advi':
443+
raise NotImplemented('Sampler {} is not supported.'.format(init))
444+
445+
if sampler == 'advi':
446+
if init != 'advi':
447+
raise ValueError("To sample via ADVI, you have to set init='advi'.")
448+
trace = pm.variational.sample_vp(v_params, draws=draws)
449+
else:
450+
trace = pm.sample(step=step, start=start, draws=draws, **kwargs)
451+
452+
return trace

pymc3/tests/test_sampling.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,20 @@ def test_sample(self):
6161
for steps in [1, 10, 300]:
6262
pm.sample(steps, self.step, {}, None, njobs=njobs, random_seed=self.random_seed)
6363

64+
def test_sample_init(self):
65+
with self.model:
66+
for init in ('advi', 'map', 'metropolis', 'nuts'):
67+
for sampler in ('nuts', 'hmc', 'advi'):
68+
if (sampler == 'advi') and (init != 'advi'):
69+
self.assertRaises(ValueError, pm.sample_init,
70+
init=init, sampler=sampler,
71+
n_init=1000)
72+
else:
73+
pm.sample_init(init=init, sampler=sampler,
74+
n_init=1000, draws=50,
75+
random_seed=self.random_seed)
76+
77+
6478
def test_iter_sample(self):
6579
with self.model:
6680
samps = pm.sampling.iter_sample(5, self.step, self.start, random_seed=self.random_seed)

0 commit comments

Comments
 (0)