Skip to content

Commit 85b46f6

Browse files
committed
MAINT Remove Metropolis init as it does not seem to work well.
1 parent dc203d9 commit 85b46f6

File tree

2 files changed

+4
-12
lines changed

2 files changed

+4
-12
lines changed

pymc3/sampling.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def sample_init(draws=2000, init='advi', n_init=500000, sampler='nuts',
386386
387387
Parameteres
388388
-----------
389-
init : str {'advi', 'map', 'metropolis', 'nuts'}
389+
init : str {'advi', 'map', 'nuts'}
390390
Initialization method to use.
391391
n_init : int
392392
Number of iterations of initializer
@@ -412,20 +412,15 @@ def sample_init(draws=2000, init='advi', n_init=500000, sampler='nuts',
412412
if init == 'advi':
413413
v_params = pm.variational.advi(n=n_init)
414414
start = v_params.means
415-
cov = np.diagflat(np.power(model.dict_to_array(v_params.stds), 2))
415+
cov = np.power(model.dict_to_array(v_params.stds), 2)
416416

417417
elif init == 'map':
418418
start = pm.find_MAP()
419419
cov = pm.find_hessian(point=start)
420420

421-
elif init == 'metropolis':
422-
init_trace = pm.sample(step=pm.Metropolis(), draws=n_init)
423-
cov = pm.trace_cov(init_trace)
424-
425-
start = {varname: np.mean(init_trace[varname]) for varname in init_trace.varnames}
426421
elif init == 'nuts':
427422
init_trace = pm.sample(step=pm.NUTS(), draws=n_init)
428-
cov = pm.trace_cov(init_trace)
423+
cov = pm.trace_cov(init_trace[n_init//2:])
429424

430425
start = {varname: np.mean(init_trace[varname]) for varname in init_trace.varnames}
431426
else:
@@ -436,9 +431,6 @@ def sample_init(draws=2000, init='advi', n_init=500000, sampler='nuts',
436431
step = pm.NUTS(scaling=cov, is_cov=True)
437432
elif sampler == 'hmc':
438433
step = pm.HamiltonianMC(scaling=cov, is_cov=True)
439-
elif sampler == 'metropolis':
440-
step = pm.Metropolis(scaling=cov,
441-
proposal=pm.step_methods.metropolis.MultivariateNormalProposal)
442434
elif sampler != 'advi':
443435
raise NotImplemented('Sampler {} is not supported.'.format(init))
444436

pymc3/tests/test_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_sample(self):
6363

6464
def test_sample_init(self):
6565
with self.model:
66-
for init in ('advi', 'map', 'metropolis', 'nuts'):
66+
for init in ('advi', 'map', 'nuts'):
6767
for sampler in ('nuts', 'hmc', 'advi'):
6868
if (sampler == 'advi') and (init != 'advi'):
6969
self.assertRaises(ValueError, pm.sample_init,

0 commit comments

Comments
 (0)