1717
1818# optimal tuning for HMC, see https://arxiv.org/abs/1001.4460
1919OPTIMAL_TARGET_ACCEPTANCE_RATE = 0.651
20+ # Clip the final log-space update like the original implementation in TFP (~log(2)/2 ≈ 0.35).
21+ LOG_UPDATE_CLIP = 0.35
22+ # Small constant to avoid division by zero or log of zero
23+ EPS_FLOAT = 1e-20
2024
2125
2226class ChEESAdaptationState (NamedTuple ):
@@ -52,12 +56,24 @@ class ChEESAdaptationState(NamedTuple):
5256 step : int
5357
5458
def weighted_empirical_mean(x, w):
    """Weighted mean of ``x`` over the leading (chain) axis, robust to NaN/inf.

    Parameters
    ----------
    x
        Array of shape ``(num_chains, ...)`` holding per-chain values
        (typically one leaf of a pytree of proposed positions).
    w
        Array of shape ``(num_chains,)`` of non-negative per-chain weights
        (e.g. acceptance probabilities).

    Returns
    -------
    Array of shape ``x.shape[1:]``: the weighted mean over axis 0, where any
    chain containing a non-finite entry contributes zero weight. The result
    has gradients stopped so the adaptation statistics are not differentiated
    through.
    """
    finite = jnp.isfinite(x)
    # Zero out non-finite entries so they cannot contaminate the sum
    # (0 * w is well-defined, inf * w is not).
    x_safe = jnp.where(finite, x, 0.0)
    # A chain is dropped entirely if ANY of its entries is non-finite.
    # Reduce over every axis except the leading chain axis so the mask has
    # shape (num_chains,) regardless of the leaf's rank — `.all(axis=-1)`
    # would only be correct for exactly 2-D leaves.
    chain_ok = finite.all(axis=tuple(range(1, x.ndim)))
    w = jnp.where(chain_ok, w, 0.0)

    # Broadcast weights against the trailing dimensions of x.
    w_exp = w.reshape((w.shape[0],) + (1,) * (x.ndim - 1))
    num = jnp.sum(w_exp * x_safe, axis=0)
    # EPS_FLOAT guards against division by zero when every chain is masked out.
    den = jnp.sum(w_exp, axis=0) + EPS_FLOAT
    return jax.lax.stop_gradient(num / den)
69+
5570def base (
5671 jitter_generator : Callable ,
5772 next_random_arg_fn : Callable ,
5873 optim : optax .GradientTransformation ,
5974 target_acceptance_rate : float ,
6075 decay_rate : float ,
76+ max_leapfrog_steps : int ,
6177) -> Tuple [Callable , Callable ]:
6278 """Maximizing the Change in the Estimator of the Expected Square criterion
6379 (trajectory length) and dual averaging procedure (step size) for the jittered
@@ -144,6 +160,8 @@ def compute_parameters(
144160 harmonic_mean = 1.0 / jnp .mean (
145161 1.0 / acceptance_probabilities , where = ~ is_divergent
146162 )
163+ # Replace inf/nan harmonic mean as zero to avoid issues in dual averaging
164+ harmonic_mean = jnp .where (jnp .isfinite (harmonic_mean ), harmonic_mean , 0.0 )
147165 da_state_ = da_update (da_state , target_acceptance_rate - harmonic_mean )
148166 step_size_ = jnp .exp (da_state_ .log_x )
149167 new_step_size , new_da_state , new_log_step_size = jax .lax .cond (
@@ -157,9 +175,14 @@ def compute_parameters(
157175 1.0 - update_weight
158176 ) * log_step_size_ma + update_weight * new_log_step_size
159177
178+ w = jnp .where (~ is_divergent , acceptance_probabilities , 0.0 )
160179 proposals_mean = jax .tree_util .tree_map (
161- lambda p : jnp . nanmean (p , axis = 0 ), proposed_positions
180+ lambda p : weighted_empirical_mean (p , w ), proposed_positions
162181 )
182+ # The above weighted mean is presumably better than the simple mean:
183+ # proposals_mean = jax.tree_util.tree_map(
184+ # lambda p: jnp.nanmean(p, axis=0), proposed_positions
185+ # )
163186 initials_mean = jax .tree_util .tree_map (
164187 lambda p : jnp .nanmean (p , axis = 0 ), initial_positions
165188 )
@@ -177,19 +200,25 @@ def compute_parameters(
177200
178201 trajectory_gradients = (
179202 jitter_generator (random_generator_arg )
180- * trajectory_length
203+ * trajectory_length # this effectively make this gradient w.r.t. log_trajectory_length
181204 * jax .vmap (
182205 lambda pm , im , mm : (jnp .dot (pm , pm ) - jnp .dot (im , im )) * jnp .dot (pm , mm )
183206 )(proposals_matrix , initials_matrix , momentums_matrix )
184207 )
208+
185209 trajectory_gradient = jnp .sum (
186- acceptance_probabilities * trajectory_gradients , where = ~ is_divergent
187- ) / jnp .sum (acceptance_probabilities , where = ~ is_divergent )
210+ acceptance_probabilities * trajectory_gradients ,
211+ where = ~ is_divergent ,
212+ ) / jnp .sum (acceptance_probabilities + EPS_FLOAT , where = ~ is_divergent )
188213
189214 log_trajectory_length = jnp .log (trajectory_length )
190215 updates , optim_state_ = optim .update (
191216 trajectory_gradient , optim_state , log_trajectory_length
192217 )
218+
219+ updates = jax .tree_util .tree_map (
220+ lambda u : jnp .clip (u , - LOG_UPDATE_CLIP , LOG_UPDATE_CLIP ), updates
221+ )
193222 log_trajectory_length_ = optax .apply_updates (log_trajectory_length , updates )
194223 new_log_trajectory_length , new_optim_state = jax .lax .cond (
195224 jnp .isfinite (
@@ -204,6 +233,13 @@ def compute_parameters(
204233 ) * log_trajectory_length_ma + update_weight * new_log_trajectory_length
205234 new_trajectory_length = jnp .exp (new_log_trajectory_length_ma )
206235
236+ # clip new trajectory length to avoid too large trajectories, also the
237+ # minimum trajectory length is one integrator step
238+ new_trajectory_length = jnp .clip (
239+ new_trajectory_length ,
240+ max = max_leapfrog_steps * new_step_size ,
241+ min = new_step_size ,
242+ )
207243 return ChEESAdaptationState (
208244 new_step_size ,
209245 new_log_step_size_ma ,
@@ -278,6 +314,7 @@ def chees_adaptation(
278314 jitter_amount : float = 1.0 ,
279315 target_acceptance_rate : float = OPTIMAL_TARGET_ACCEPTANCE_RATE ,
280316 decay_rate : float = 0.5 ,
317+ max_leapfrog_steps : int = 1000 ,
281318 adaptation_info_fn : Callable = return_all_adapt_info ,
282319) -> AdaptationAlgorithm :
283320 """Adapt the step size and trajectory length (number of integration steps / step size)
@@ -376,13 +413,14 @@ def run(
376413 jax .random .fold_in (carry_key , i )
377414 ) * jitter_amount + (1.0 - jitter_amount )
378415 else :
416+ max_bits = np .ceil (np .log2 (num_steps + max_sampling_steps ))
379417 jitter_gn = lambda i : dynamic_hmc .halton_sequence (
380- i , np . ceil ( np . log2 ( num_steps + max_sampling_steps ))
418+ i , max_bits
381419 ) * jitter_amount + (1.0 - jitter_amount )
382420
383- def integration_steps_fn (random_generator_arg , trajectory_length_adjusted ):
421+ def integration_steps_fn (random_generator_arg , num_leapfrog_steps ):
384422 return jnp .asarray (
385- jnp .ceil (jitter_gn (random_generator_arg ) * trajectory_length_adjusted ),
423+ jnp .ceil (jitter_gn (random_generator_arg ) * num_leapfrog_steps ),
386424 dtype = int ,
387425 )
388426
@@ -392,7 +430,12 @@ def integration_steps_fn(random_generator_arg, trajectory_length_adjusted):
392430 )
393431
394432 init , update = base (
395- jitter_gn , next_random_arg_fn , optim , target_acceptance_rate , decay_rate
433+ jitter_gn ,
434+ next_random_arg_fn ,
435+ optim ,
436+ target_acceptance_rate ,
437+ decay_rate ,
438+ max_leapfrog_steps ,
396439 )
397440
398441 def one_step (carry , rng_key ):
@@ -404,7 +447,7 @@ def one_step(carry, rng_key):
404447 logdensity_fn = logdensity_fn ,
405448 step_size = adaptation_state .step_size ,
406449 inverse_mass_matrix = jnp .ones (num_dim ),
407- trajectory_length_adjusted = adaptation_state .trajectory_length
450+ num_leapfrog_steps = adaptation_state .trajectory_length
408451 / adaptation_state .step_size ,
409452 )
410453 new_states , info = jax .vmap (_step_fn )(keys , states )
@@ -432,7 +475,7 @@ def one_step(carry, rng_key):
432475 one_step , (init_states , init_adaptation_state ), keys_step
433476 )
434477
435- trajectory_length_adjusted = jnp .exp (
478+ num_leapfrog_steps = jnp .exp (
436479 last_adaptation_state .log_trajectory_length_moving_average
437480 - last_adaptation_state .log_step_size_moving_average
438481 )
@@ -441,7 +484,7 @@ def one_step(carry, rng_key):
441484 "inverse_mass_matrix" : jnp .ones (num_dim ),
442485 "next_random_arg_fn" : next_random_arg_fn ,
443486 "integration_steps_fn" : lambda arg : integration_steps_fn (
444- arg , trajectory_length_adjusted
487+ arg , num_leapfrog_steps
445488 ),
446489 }
447490
0 commit comments