Skip to content

Commit 1b78027

Browse files
authored
Improve the robustness of ChEES-HMC. (#803)
* Improve the robustness of ChEES-HMC. * Add a test for halton_sequence to raise ValueError when max_bits is too large. * Move constant definitions to top of the file.
1 parent 2b31129 commit 1b78027

File tree

3 files changed

+104
-23
lines changed

3 files changed

+104
-23
lines changed

blackjax/adaptation/chees_adaptation.py

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818
# optimal tuning for HMC, see https://arxiv.org/abs/1001.4460
1919
OPTIMAL_TARGET_ACCEPTANCE_RATE = 0.651
20+
# Clip the final log-space update like the original implementation in TFP (~log(2)/2 ≈ 0.35).
21+
LOG_UPDATE_CLIP = 0.35
22+
# Small constant to avoid division by zero or log of zero
23+
EPS_FLOAT = 1e-20
2024

2125

2226
class ChEESAdaptationState(NamedTuple):
@@ -52,12 +56,24 @@ class ChEESAdaptationState(NamedTuple):
5256
step: int
5357

5458

59+
def weighted_empirical_mean(x, w):
    """Weighted mean of ``x`` over its leading (chain) axis, ignoring bad chains.

    Parameters
    ----------
    x
        Array whose leading axis indexes chains, e.g. ``(num_chains, dim)``.
        Leaves with a single axis ``(num_chains,)`` or with more than two axes
        are handled as well.
    w
        Per-chain weights of shape ``(num_chains,)``.

    Returns
    -------
    The weighted mean over the leading axis, with shape ``x.shape[1:]``
    (``(1,)`` when ``x`` is one-dimensional). The result is wrapped in
    ``jax.lax.stop_gradient``.

    Notes
    -----
    Non-finite entries of ``x`` are replaced by zero, and any chain containing
    at least one non-finite entry receives zero weight. ``EPS_FLOAT`` is added
    to the denominator so that an all-zero weight vector yields zero instead
    of a division by zero.
    """
    num_chains = x.shape[0]
    finite = jnp.isfinite(x)
    x_safe = jnp.where(finite, x, 0.0)

    # Reduce over *all* non-chain axes so the mask is always (num_chains,).
    # Reducing over only the last axis collapses a 1-D input to a scalar (one
    # bad chain would zero every weight) and mis-broadcasts for ndim > 2.
    chain_is_finite = finite.reshape((num_chains, -1)).all(axis=1)
    w = jnp.where(chain_is_finite, w, 0.0)

    # Give the weights trailing singleton axes so they broadcast against x.
    w_exp = w.reshape((num_chains,) + (1,) * (x.ndim - 1))
    num = jnp.sum(w_exp * x_safe, axis=0)
    den = jnp.sum(w_exp, axis=0) + EPS_FLOAT
    return jax.lax.stop_gradient(num / den)
68+
69+
5570
def base(
5671
jitter_generator: Callable,
5772
next_random_arg_fn: Callable,
5873
optim: optax.GradientTransformation,
5974
target_acceptance_rate: float,
6075
decay_rate: float,
76+
max_leapfrog_steps: int,
6177
) -> Tuple[Callable, Callable]:
6278
"""Maximizing the Change in the Estimator of the Expected Square criterion
6379
(trajectory length) and dual averaging procedure (step size) for the jittered
@@ -144,6 +160,8 @@ def compute_parameters(
144160
harmonic_mean = 1.0 / jnp.mean(
145161
1.0 / acceptance_probabilities, where=~is_divergent
146162
)
163+
# Replace a non-finite (inf/nan) harmonic mean with zero to avoid issues in dual averaging
164+
harmonic_mean = jnp.where(jnp.isfinite(harmonic_mean), harmonic_mean, 0.0)
147165
da_state_ = da_update(da_state, target_acceptance_rate - harmonic_mean)
148166
step_size_ = jnp.exp(da_state_.log_x)
149167
new_step_size, new_da_state, new_log_step_size = jax.lax.cond(
@@ -157,9 +175,14 @@ def compute_parameters(
157175
1.0 - update_weight
158176
) * log_step_size_ma + update_weight * new_log_step_size
159177

178+
w = jnp.where(~is_divergent, acceptance_probabilities, 0.0)
160179
proposals_mean = jax.tree_util.tree_map(
161-
lambda p: jnp.nanmean(p, axis=0), proposed_positions
180+
lambda p: weighted_empirical_mean(p, w), proposed_positions
162181
)
182+
# The above weighted mean is presumably better than the simple mean:
183+
# proposals_mean = jax.tree_util.tree_map(
184+
# lambda p: jnp.nanmean(p, axis=0), proposed_positions
185+
# )
163186
initials_mean = jax.tree_util.tree_map(
164187
lambda p: jnp.nanmean(p, axis=0), initial_positions
165188
)
@@ -177,19 +200,25 @@ def compute_parameters(
177200

178201
trajectory_gradients = (
179202
jitter_generator(random_generator_arg)
180-
* trajectory_length
203+
* trajectory_length # this effectively makes the gradient w.r.t. log_trajectory_length
181204
* jax.vmap(
182205
lambda pm, im, mm: (jnp.dot(pm, pm) - jnp.dot(im, im)) * jnp.dot(pm, mm)
183206
)(proposals_matrix, initials_matrix, momentums_matrix)
184207
)
208+
185209
trajectory_gradient = jnp.sum(
186-
acceptance_probabilities * trajectory_gradients, where=~is_divergent
187-
) / jnp.sum(acceptance_probabilities, where=~is_divergent)
210+
acceptance_probabilities * trajectory_gradients,
211+
where=~is_divergent,
212+
) / jnp.sum(acceptance_probabilities + EPS_FLOAT, where=~is_divergent)
188213

189214
log_trajectory_length = jnp.log(trajectory_length)
190215
updates, optim_state_ = optim.update(
191216
trajectory_gradient, optim_state, log_trajectory_length
192217
)
218+
219+
updates = jax.tree_util.tree_map(
220+
lambda u: jnp.clip(u, -LOG_UPDATE_CLIP, LOG_UPDATE_CLIP), updates
221+
)
193222
log_trajectory_length_ = optax.apply_updates(log_trajectory_length, updates)
194223
new_log_trajectory_length, new_optim_state = jax.lax.cond(
195224
jnp.isfinite(
@@ -204,6 +233,13 @@ def compute_parameters(
204233
) * log_trajectory_length_ma + update_weight * new_log_trajectory_length
205234
new_trajectory_length = jnp.exp(new_log_trajectory_length_ma)
206235

236+
# Clip the new trajectory length: at most max_leapfrog_steps integrator steps
237+
# (to avoid overly long trajectories) and at least one integrator step.
238+
new_trajectory_length = jnp.clip(
239+
new_trajectory_length,
240+
max=max_leapfrog_steps * new_step_size,
241+
min=new_step_size,
242+
)
207243
return ChEESAdaptationState(
208244
new_step_size,
209245
new_log_step_size_ma,
@@ -278,6 +314,7 @@ def chees_adaptation(
278314
jitter_amount: float = 1.0,
279315
target_acceptance_rate: float = OPTIMAL_TARGET_ACCEPTANCE_RATE,
280316
decay_rate: float = 0.5,
317+
max_leapfrog_steps: int = 1000,
281318
adaptation_info_fn: Callable = return_all_adapt_info,
282319
) -> AdaptationAlgorithm:
283320
"""Adapt the step size and trajectory length (number of integration steps / step size)
@@ -376,13 +413,14 @@ def run(
376413
jax.random.fold_in(carry_key, i)
377414
) * jitter_amount + (1.0 - jitter_amount)
378415
else:
416+
max_bits = np.ceil(np.log2(num_steps + max_sampling_steps))
379417
jitter_gn = lambda i: dynamic_hmc.halton_sequence(
380-
i, np.ceil(np.log2(num_steps + max_sampling_steps))
418+
i, max_bits
381419
) * jitter_amount + (1.0 - jitter_amount)
382420

383-
def integration_steps_fn(random_generator_arg, trajectory_length_adjusted):
421+
def integration_steps_fn(random_generator_arg, num_leapfrog_steps):
384422
return jnp.asarray(
385-
jnp.ceil(jitter_gn(random_generator_arg) * trajectory_length_adjusted),
423+
jnp.ceil(jitter_gn(random_generator_arg) * num_leapfrog_steps),
386424
dtype=int,
387425
)
388426

@@ -392,7 +430,12 @@ def integration_steps_fn(random_generator_arg, trajectory_length_adjusted):
392430
)
393431

394432
init, update = base(
395-
jitter_gn, next_random_arg_fn, optim, target_acceptance_rate, decay_rate
433+
jitter_gn,
434+
next_random_arg_fn,
435+
optim,
436+
target_acceptance_rate,
437+
decay_rate,
438+
max_leapfrog_steps,
396439
)
397440

398441
def one_step(carry, rng_key):
@@ -404,7 +447,7 @@ def one_step(carry, rng_key):
404447
logdensity_fn=logdensity_fn,
405448
step_size=adaptation_state.step_size,
406449
inverse_mass_matrix=jnp.ones(num_dim),
407-
trajectory_length_adjusted=adaptation_state.trajectory_length
450+
num_leapfrog_steps=adaptation_state.trajectory_length
408451
/ adaptation_state.step_size,
409452
)
410453
new_states, info = jax.vmap(_step_fn)(keys, states)
@@ -432,7 +475,7 @@ def one_step(carry, rng_key):
432475
one_step, (init_states, init_adaptation_state), keys_step
433476
)
434477

435-
trajectory_length_adjusted = jnp.exp(
478+
num_leapfrog_steps = jnp.exp(
436479
last_adaptation_state.log_trajectory_length_moving_average
437480
- last_adaptation_state.log_step_size_moving_average
438481
)
@@ -441,7 +484,7 @@ def one_step(carry, rng_key):
441484
"inverse_mass_matrix": jnp.ones(num_dim),
442485
"next_random_arg_fn": next_random_arg_fn,
443486
"integration_steps_fn": lambda arg: integration_steps_fn(
444-
arg, trajectory_length_adjusted
487+
arg, num_leapfrog_steps
445488
),
446489
}
447490

blackjax/mcmc/dynamic_hmc.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Public API for the Dynamic HMC Kernel"""
15+
1516
from typing import Callable, NamedTuple
1617

1718
import jax
@@ -46,7 +47,11 @@ class DynamicHMCState(NamedTuple):
4647
random_generator_arg: Array
4748

4849

49-
def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array):
50+
def init(
51+
position: ArrayLikeTree,
52+
logdensity_fn: Callable,
53+
random_generator_arg: Array,
54+
):
5055
logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
5156
return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg)
5257

@@ -154,7 +159,10 @@ def as_top_level_api(
154159
A ``SamplingAlgorithm``.
155160
"""
156161
kernel = build_kernel(
157-
integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn
162+
integrator,
163+
divergence_threshold,
164+
next_random_arg_fn,
165+
integration_steps_fn,
158166
)
159167

160168
def init_fn(position: ArrayLikeTree, rng_key: Array):
@@ -176,6 +184,14 @@ def step_fn(rng_key: PRNGKey, state):
176184

177185

178186
def halton_sequence(i: Array, max_bits: int = 10) -> float:
    """Generate the (i+1)-th element of the base-2 Halton sequence.

    The value is obtained by reversing the lowest ``max_bits`` binary digits
    of ``i + 1`` around the binary point, e.g. ``i = 0, 1, 2`` give
    ``0.5, 0.25, 0.75``.

    Parameters
    ----------
    i
        Integer index of the element. Plain Python integers are accepted and
        converted with ``jnp.asarray``.
    max_bits
        Number of binary digits of ``i + 1`` that are used.

    Returns
    -------
    The requested element of the sequence as a scalar.

    Raises
    ------
    ValueError
        If ``max_bits`` is not smaller than the bit width of ``i``'s integer
        dtype (e.g. ``max_bits`` must be <= 63 for int64), since the bit masks
        ``2 ** k`` would otherwise overflow.
    """
    # Coerce first so that Python ints (which have no ``dtype``) are supported.
    i = jnp.asarray(i)
    dtype_bits = jnp.iinfo(i.dtype).bits
    if max_bits >= dtype_bits:
        raise ValueError(
            f"max_bits ({max_bits}) must be less than bit width of dtype {i.dtype} ({dtype_bits})"
        )
    # bit_masks[k] == 2**k; digit k of (i + 1) contributes 2**-(k + 1).
    bit_masks = 2 ** jnp.arange(max_bits, dtype=i.dtype)
    return jnp.einsum("i,i->", jnp.mod((i + 1) // bit_masks, 2), 0.5 / bit_masks)
181197

tests/adaptation/test_adaptation.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,24 @@ def test_adaptation_schedule(num_steps, expected_schedule):
6161
],
6262
)
6363
def test_chees_adaptation(adaptation_filters):
64+
target_mean = jnp.array([0.0, 0.0])
65+
target_std = jnp.array([1.0, 10.0])
6466
logprob_fn = lambda x: jax.scipy.stats.norm.logpdf(
65-
x, loc=0.0, scale=jnp.array([1.0, 10.0])
66-
).sum()
67+
x, loc=target_mean, scale=target_std
68+
).sum(axis=-1)
6769

6870
num_burnin_steps = 1000
6971
num_results = 500
7072
num_chains = 16
7173
step_size = 0.1
74+
target_acceptance_rate = 0.75
7275

7376
init_key, warmup_key, inference_key = jax.random.split(jax.random.key(346), 3)
7477

7578
warmup = blackjax.chees_adaptation(
7679
logprob_fn,
7780
num_chains=num_chains,
78-
target_acceptance_rate=0.75,
81+
target_acceptance_rate=target_acceptance_rate,
7982
adaptation_info_fn=adaptation_filters["filter_fn"],
8083
)
8184

@@ -84,13 +87,12 @@ def test_chees_adaptation(adaptation_filters):
8487
warmup_key,
8588
initial_positions,
8689
step_size=step_size,
87-
optim=optax.adamw(learning_rate=0.5),
90+
optim=optax.adam(learning_rate=0.5, b1=0, b2=0.95),
8891
num_steps=num_burnin_steps,
8992
)
9093
algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters)
91-
9294
chain_keys = jax.random.split(inference_key, num_chains)
93-
_, (_, infos) = jax.vmap(
95+
final_states, (states, infos) = jax.vmap(
9496
lambda key, state: run_inference_algorithm(
9597
rng_key=key,
9698
initial_state=state,
@@ -99,7 +101,9 @@ def test_chees_adaptation(adaptation_filters):
99101
)
100102
)(chain_keys, last_states)
101103

102-
harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)
104+
harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate, axis=0)
105+
assert harmonic_mean.shape == (num_results,)
106+
harmonic_mean = jnp.mean(harmonic_mean)
103107

104108
def check_attrs(attribute, keyset):
105109
for name, param in getattr(warmup_info, attribute)._asdict().items():
@@ -119,6 +123,24 @@ def check_attrs(attribute, keyset):
119123
for i, attribute in enumerate(["state", "info", "adaptation_state"]):
120124
check_attrs(attribute, keysets[i])
121125

122-
np.testing.assert_allclose(harmonic_mean, 0.75, atol=1e-1)
123-
np.testing.assert_allclose(parameters["step_size"], 1.5, rtol=2e-1)
124-
np.testing.assert_array_less(infos.num_integration_steps.mean(), 15.0)
126+
# The harmonic mean of the acceptance rate should be close to the target acceptance rate
127+
np.testing.assert_allclose(harmonic_mean, target_acceptance_rate, atol=1e-1)
128+
129+
# These are empirical values that should be roughly correct for this target distribution
130+
np.testing.assert_allclose(parameters["step_size"], 1.5, atol=0.3)
131+
np.testing.assert_allclose(infos.num_integration_steps.mean(), 9, atol=3)
132+
133+
# Check that sample means and stds are close to target values
134+
draws = states.position.reshape(-1, states.position.shape[-1])
135+
empirical_mean = jnp.mean(draws, axis=0)
136+
empirical_std = jnp.std(draws, axis=0)
137+
np.testing.assert_allclose(empirical_mean, target_mean, atol=0.5)
138+
np.testing.assert_allclose(empirical_std, target_std, rtol=0.1)
139+
140+
141+
def test_halton_sequence_raise_value():
142+
"""Test that halton sequence raises value error when max_bits is too large."""
143+
from blackjax.mcmc.dynamic_hmc import halton_sequence
144+
145+
with pytest.raises(ValueError, match="max_bits"):
146+
halton_sequence(jnp.array([0], dtype=jnp.int32), max_bits=32)

0 commit comments

Comments
 (0)