Skip to content

Commit 017ac03

Browse files
authored
Remove find_reasonable_step_size (#553)
1 parent cc3676f commit 017ac03

File tree

3 files changed

+27
-8
lines changed

3 files changed

+27
-8
lines changed

numpyro/infer/hmc_util.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def _identity_step_size(inverse_mass_matrix, z, rng_key, step_size):
328328
return step_size
329329

330330

331-
def warmup_adapter(num_adapt_steps, find_reasonable_step_size=_identity_step_size,
331+
def warmup_adapter(num_adapt_steps, find_reasonable_step_size=None,
332332
adapt_step_size=True, adapt_mass_matrix=True,
333333
dense_mass=False, target_accept_prob=0.8):
334334
"""
@@ -349,6 +349,8 @@ def warmup_adapter(num_adapt_steps, find_reasonable_step_size=_identity_step_siz
349349
step size, hence the sampling will be slower but more robust. Default to 0.8.
350350
:return: a pair of (`init_fn`, `update_fn`).
351351
"""
352+
if find_reasonable_step_size is None:
353+
find_reasonable_step_size = _identity_step_size
352354
ss_init, ss_update = dual_averaging()
353355
mm_init, mm_update, mm_final = welford_covariance(diagonal=not dense_mass)
354356
adaptation_schedule = np.array(build_adaptation_schedule(num_adapt_steps))

numpyro/infer/mcmc.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def init_kernel(init_params,
192192
target_accept_prob=0.8,
193193
trajectory_length=2*math.pi,
194194
max_tree_depth=10,
195+
find_heuristic_step_size=False,
195196
model_args=(),
196197
model_kwargs=None,
197198
rng_key=PRNGKey(0)):
@@ -221,6 +222,8 @@ def init_kernel(init_params,
221222
value is :math:`2\\pi`.
222223
:param int max_tree_depth: Max depth of the binary tree created during the doubling
223224
scheme of NUTS sampler. Defaults to 10.
225+
:param bool find_heuristic_step_size: whether to a heuristic function to adjust the
226+
step size at the beginning of each adaptation window. Defaults to False.
224227
:param tuple model_args: Model arguments if `potential_fn_gen` is specified.
225228
:param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified.
226229
:param jax.random.PRNGKey rng_key: random key to be used as the source of
@@ -243,8 +246,12 @@ def init_kernel(init_params,
243246
kwargs = {} if model_kwargs is None else model_kwargs
244247
pe_fn = potential_fn_gen(*model_args, **kwargs)
245248

246-
find_reasonable_ss = partial(find_reasonable_step_size,
247-
pe_fn, kinetic_fn, momentum_generator)
249+
find_reasonable_ss = None
250+
if find_heuristic_step_size:
251+
find_reasonable_ss = partial(find_reasonable_step_size,
252+
pe_fn,
253+
kinetic_fn,
254+
momentum_generator)
248255

249256
wa_init, wa_update = warmup_adapter(num_warmup,
250257
adapt_step_size=adapt_step_size,
@@ -437,6 +444,8 @@ class HMC(MCMCKernel):
437444
value is :math:`2\\pi`.
438445
:param callable init_strategy: a per-site initialization function.
439446
See :ref:`init_strategy` section for available functions.
447+
:param bool find_heuristic_step_size: whether to a heuristic function to adjust the
448+
step size at the beginning of each adaptation window. Defaults to False.
440449
"""
441450
def __init__(self,
442451
model=None,
@@ -448,7 +457,8 @@ def __init__(self,
448457
dense_mass=False,
449458
target_accept_prob=0.8,
450459
trajectory_length=2 * math.pi,
451-
init_strategy=init_to_uniform()):
460+
init_strategy=init_to_uniform(),
461+
find_heuristic_step_size=False):
452462
if not (model is None) ^ (potential_fn is None):
453463
raise ValueError('Only one of `model` or `potential_fn` must be specified.')
454464
self._model = model
@@ -463,6 +473,7 @@ def __init__(self,
463473
self._algo = 'HMC'
464474
self._max_tree_depth = 10
465475
self._init_strategy = init_strategy
476+
self._find_heuristic_step_size = find_heuristic_step_size
466477
# Set on first call to init
467478
self._init_fn = None
468479
self._postprocess_fn = None
@@ -525,9 +536,10 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg
525536
target_accept_prob=self._target_accept_prob,
526537
trajectory_length=self._trajectory_length,
527538
max_tree_depth=self._max_tree_depth,
528-
rng_key=rng_key,
539+
find_heuristic_step_size=self._find_heuristic_step_size,
529540
model_args=model_args,
530541
model_kwargs=model_kwargs,
542+
rng_key=rng_key,
531543
)
532544
if rng_key.ndim == 1:
533545
init_state = hmc_init_fn(init_params, rng_key)
@@ -600,6 +612,8 @@ class NUTS(HMC):
600612
scheme of NUTS sampler. Defaults to 10.
601613
:param callable init_strategy: a per-site initialization function.
602614
See :ref:`init_strategy` section for available functions.
615+
:param bool find_heuristic_step_size: whether to a heuristic function to adjust the
616+
step size at the beginning of each adaptation window. Defaults to False.
603617
"""
604618
def __init__(self,
605619
model=None,
@@ -612,12 +626,15 @@ def __init__(self,
612626
target_accept_prob=0.8,
613627
trajectory_length=None,
614628
max_tree_depth=10,
615-
init_strategy=init_to_uniform()):
629+
init_strategy=init_to_uniform(),
630+
find_heuristic_step_size=False):
616631
super(NUTS, self).__init__(potential_fn=potential_fn, model=model, kinetic_fn=kinetic_fn,
617632
step_size=step_size, adapt_step_size=adapt_step_size,
618633
adapt_mass_matrix=adapt_mass_matrix, dense_mass=dense_mass,
619634
target_accept_prob=target_accept_prob,
620-
trajectory_length=trajectory_length, init_strategy=init_strategy)
635+
trajectory_length=trajectory_length,
636+
init_strategy=init_strategy,
637+
find_heuristic_step_size=find_heuristic_step_size)
621638
self._max_tree_depth = max_tree_depth
622639
self._algo = 'NUTS'
623640

test/test_mcmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def model(labels):
8787
if kernel_cls is SA:
8888
kernel = SA(model=model, adapt_state_size=9)
8989
else:
90-
kernel = kernel_cls(model=model, trajectory_length=8)
90+
kernel = kernel_cls(model=model, trajectory_length=8, find_heuristic_step_size=True)
9191
mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
9292
mcmc.run(random.PRNGKey(2), labels)
9393
mcmc.print_summary()

0 commit comments

Comments
 (0)