Skip to content

Commit ed2aa49

Browse files
aseyboldtJunpeng Lao
authored andcommitted
Change nuts init default to jitter+adapt_diag (#2475)
1 parent f581520 commit ed2aa49

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

pymc3/sampling.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,15 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
123123
Initialization method to use for auto-assigned NUTS samplers.
124124
125125
* auto : Choose a default initialization method automatically.
126-
Currently, this is `'advi+adapt_diag'`, but this can change in
126+
Currently, this is `'unif+adapt_diag'`, but this can change in
127127
the future. If you depend on the exact behaviour, choose an
128128
initialization method explicitly.
129129
* adapt_diag : Start with a identity mass matrix and then adapt
130-
a diagonal based on the variance of the tuning samples.
130+
a diagonal based on the variance of the tuning samples. All
131+
chains use the test value (usually the prior mean) as starting
132+
point.
133+
* jitter+adapt_diag : Same as `adapt_diag`, but add uniform jitter
134+
in [-1, 1] to the starting point in each chain.
131135
* advi+adapt_diag : Run ADVI and then adapt the resulting diagonal
132136
mass matrix based on the sample variance of the tuning samples.
133137
* advi+adapt_diag_grad : Run ADVI and then adapt the resulting
@@ -695,11 +699,15 @@ def init_nuts(init='auto', njobs=1, n_init=500000, model=None,
695699
Initialization method to use.
696700
697701
* auto : Choose a default initialization method automatically.
698-
Currently, this is `'advi+adapt_diag'`, but this can change in
702+
Currently, this is `'unif+adapt_diag'`, but this can change in
699703
the future. If you depend on the exact behaviour, choose an
700704
initialization method explicitly.
701705
* adapt_diag : Start with a identity mass matrix and then adapt
702-
a diagonal based on the variance of the tuning samples.
706+
a diagonal based on the variance of the tuning samples. All
707+
chains use the test value (usually the prior mean) as starting
708+
point.
709+
* jitter+adapt_diag : Same as `adapt_diag`, but add uniform jitter
710+
in [-1, 1] to the starting point in each chain.
703711
* advi+adapt_diag : Run ADVI and then adapt the resulting diagonal
704712
mass matrix based on the sample variance of the tuning samples.
705713
* advi+adapt_diag_grad : Run ADVI and then adapt the resulting
@@ -746,7 +754,7 @@ def init_nuts(init='auto', njobs=1, n_init=500000, model=None,
746754
init = init.lower()
747755

748756
if init == 'auto':
749-
init = 'advi+adapt_diag'
757+
init = 'jitter+adapt_diag'
750758

751759
pm._log.info('Initializing NUTS using {}...'.format(init))
752760

@@ -767,6 +775,19 @@ def init_nuts(init='auto', njobs=1, n_init=500000, model=None,
767775
model.ndim, mean, var, 10)
768776
if njobs == 1:
769777
start = start[0]
778+
elif init == 'jitter+adapt_diag':
779+
start = []
780+
for _ in range(njobs):
781+
mean = {var: val.copy() for var, val in model.test_point.items()}
782+
for val in mean.values():
783+
val[...] += 2 * np.random.rand(*val.shape) - 1
784+
start.append(mean)
785+
mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
786+
var = np.ones_like(mean)
787+
potential = quadpotential.QuadPotentialDiagAdapt(
788+
model.ndim, mean, var, 10)
789+
if njobs == 1:
790+
start = start[0]
770791
elif init == 'advi+adapt_diag_grad':
771792
approx = pm.fit(
772793
random_seed=random_seed,

pymc3/tests/test_sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,8 @@ def test_sum_normal(self):
253253

254254

255255
@pytest.mark.parametrize('method', [
256-
'adapt_diag', 'advi', 'ADVI+adapt_diag', 'advi+adapt_diag_grad',
257-
'map', 'advi_map', 'nuts'
256+
'jitter+adapt_diag', 'adapt_diag', 'advi', 'ADVI+adapt_diag',
257+
'advi+adapt_diag_grad', 'map', 'advi_map', 'nuts'
258258
])
259259
def test_exec_nuts_init(method):
260260
with pm.Model() as model:

0 commit comments

Comments
 (0)