@@ -192,6 +192,7 @@ def init_kernel(init_params,
192192 target_accept_prob = 0.8 ,
193193 trajectory_length = 2 * math .pi ,
194194 max_tree_depth = 10 ,
195+ find_heuristic_step_size = False ,
195196 model_args = (),
196197 model_kwargs = None ,
197198 rng_key = PRNGKey (0 )):
@@ -221,6 +222,8 @@ def init_kernel(init_params,
221222 value is :math:`2\\ pi`.
222223 :param int max_tree_depth: Max depth of the binary tree created during the doubling
223224 scheme of NUTS sampler. Defaults to 10.
225+ :param bool find_heuristic_step_size: whether to a heuristic function to adjust the
226+ step size at the beginning of each adaptation window. Defaults to False.
224227 :param tuple model_args: Model arguments if `potential_fn_gen` is specified.
225228 :param dict model_kwargs: Model keyword arguments if `potential_fn_gen` is specified.
226229 :param jax.random.PRNGKey rng_key: random key to be used as the source of
@@ -243,8 +246,12 @@ def init_kernel(init_params,
243246 kwargs = {} if model_kwargs is None else model_kwargs
244247 pe_fn = potential_fn_gen (* model_args , ** kwargs )
245248
246- find_reasonable_ss = partial (find_reasonable_step_size ,
247- pe_fn , kinetic_fn , momentum_generator )
249+ find_reasonable_ss = None
250+ if find_heuristic_step_size :
251+ find_reasonable_ss = partial (find_reasonable_step_size ,
252+ pe_fn ,
253+ kinetic_fn ,
254+ momentum_generator )
248255
249256 wa_init , wa_update = warmup_adapter (num_warmup ,
250257 adapt_step_size = adapt_step_size ,
@@ -437,6 +444,8 @@ class HMC(MCMCKernel):
437444 value is :math:`2\\ pi`.
438445 :param callable init_strategy: a per-site initialization function.
439446 See :ref:`init_strategy` section for available functions.
447+ :param bool find_heuristic_step_size: whether to a heuristic function to adjust the
448+ step size at the beginning of each adaptation window. Defaults to False.
440449 """
441450 def __init__ (self ,
442451 model = None ,
@@ -448,7 +457,8 @@ def __init__(self,
448457 dense_mass = False ,
449458 target_accept_prob = 0.8 ,
450459 trajectory_length = 2 * math .pi ,
451- init_strategy = init_to_uniform ()):
460+ init_strategy = init_to_uniform (),
461+ find_heuristic_step_size = False ):
452462 if not (model is None ) ^ (potential_fn is None ):
453463 raise ValueError ('Only one of `model` or `potential_fn` must be specified.' )
454464 self ._model = model
@@ -463,6 +473,7 @@ def __init__(self,
463473 self ._algo = 'HMC'
464474 self ._max_tree_depth = 10
465475 self ._init_strategy = init_strategy
476+ self ._find_heuristic_step_size = find_heuristic_step_size
466477 # Set on first call to init
467478 self ._init_fn = None
468479 self ._postprocess_fn = None
@@ -525,9 +536,10 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg
525536 target_accept_prob = self ._target_accept_prob ,
526537 trajectory_length = self ._trajectory_length ,
527538 max_tree_depth = self ._max_tree_depth ,
528- rng_key = rng_key ,
539+ find_heuristic_step_size = self . _find_heuristic_step_size ,
529540 model_args = model_args ,
530541 model_kwargs = model_kwargs ,
542+ rng_key = rng_key ,
531543 )
532544 if rng_key .ndim == 1 :
533545 init_state = hmc_init_fn (init_params , rng_key )
@@ -600,6 +612,8 @@ class NUTS(HMC):
600612 scheme of NUTS sampler. Defaults to 10.
601613 :param callable init_strategy: a per-site initialization function.
602614 See :ref:`init_strategy` section for available functions.
615+ :param bool find_heuristic_step_size: whether to a heuristic function to adjust the
616+ step size at the beginning of each adaptation window. Defaults to False.
603617 """
604618 def __init__ (self ,
605619 model = None ,
@@ -612,12 +626,15 @@ def __init__(self,
612626 target_accept_prob = 0.8 ,
613627 trajectory_length = None ,
614628 max_tree_depth = 10 ,
615- init_strategy = init_to_uniform ()):
629+ init_strategy = init_to_uniform (),
630+ find_heuristic_step_size = False ):
616631 super (NUTS , self ).__init__ (potential_fn = potential_fn , model = model , kinetic_fn = kinetic_fn ,
617632 step_size = step_size , adapt_step_size = adapt_step_size ,
618633 adapt_mass_matrix = adapt_mass_matrix , dense_mass = dense_mass ,
619634 target_accept_prob = target_accept_prob ,
620- trajectory_length = trajectory_length , init_strategy = init_strategy )
635+ trajectory_length = trajectory_length ,
636+ init_strategy = init_strategy ,
637+ find_heuristic_step_size = find_heuristic_step_size )
621638 self ._max_tree_depth = max_tree_depth
622639 self ._algo = 'NUTS'
623640
0 commit comments