Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 9e5b91c

Browse files
committed
Skip trajectory building if first step diverges
1 parent e6498d4 commit 9e5b91c

File tree

3 files changed

+38
-13
lines changed

3 files changed

+38
-13
lines changed

aehmc/proposals.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,10 @@ def progressive_uniform_sampling(
8585
state, energy, weight, _ = proposal
8686
new_state, new_energy, new_weight, _ = new_proposal
8787

88+
# TODO: Make the `at.isnan` check unnecessary
8889
p_accept = at.expit(new_weight - weight)
90+
p_accept = at.where(at.isnan(p_accept), 0, p_accept)
91+
8992
do_accept = srng.bernoulli(p_accept)
9093
updated_proposal = maybe_update_proposal(do_accept, proposal, new_proposal)
9194

aehmc/trajectory.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,21 @@ def add_one_state(
266266
state[1],
267267
0,
268268
)
269+
full_initial_state = (
270+
*proposal[0],
271+
proposal[1],
272+
proposal[2],
273+
proposal[3],
274+
*state,
275+
momentum_sum,
276+
*termination_state,
277+
at.as_tensor(1, dtype=np.int32),
278+
is_diverging,
279+
np.array(False),
280+
)
269281

270282
steps = at.arange(1, 1 + max_num_steps)
271-
traj, updates = aesara.scan(
283+
trajectory, updates = aesara.scan(
272284
add_one_state,
273285
outputs_info=(
274286
*proposal[0],
@@ -284,19 +296,28 @@ def add_one_state(
284296
),
285297
sequences=steps,
286298
)
299+
full_last_state = tuple([state[-1] for state in trajectory])
300+
301+
# We build the trajectory iff the first step is not diverging
302+
full_state = ifelse(is_diverging, full_initial_state, full_last_state)
287303

288304
new_proposal = (
289-
(traj[0][-1], traj[1][-1], traj[2][-1], traj[3][-1]),
290-
traj[4][-1],
291-
traj[5][-1],
292-
traj[6][-1],
305+
(full_state[0], full_state[1], full_state[2], full_state[3]),
306+
full_state[4],
307+
full_state[5],
308+
full_state[6],
309+
)
310+
new_state = (full_state[7], full_state[8], full_state[9], full_state[10])
311+
subtree_momentum_sum = full_state[11]
312+
new_termination_state = (
313+
full_state[12],
314+
full_state[13],
315+
full_state[14],
316+
full_state[15],
293317
)
294-
new_state = (traj[7][-1], traj[8][-1], traj[9][-1], traj[10][-1])
295-
subtree_momentum_sum = traj[11][-1]
296-
new_termination_state = (traj[12][-1], traj[13][-1], traj[14][-1], traj[15][-1])
297-
trajectory_length = traj[-3][-1] # defined as the number of steps taken
298-
is_diverging = traj[-2][-1] | is_diverging
299-
has_terminated = traj[-1][-1]
318+
trajectory_length = full_state[-3]
319+
is_diverging = full_state[-2]
320+
has_terminated = full_state[-1]
300321

301322
return (
302323
new_proposal,

tests/test_trajectory.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ def test_static_integration(example):
7171
"case",
7272
[
7373
(0.0000001, False, False),
74-
(1000, True, True),
74+
(1000, True, False),
75+
(1e100, True, False),
7576
],
7677
)
7778
def test_dynamic_integration(case):
@@ -136,7 +137,7 @@ def potential_fn(x):
136137
@pytest.mark.parametrize(
137138
"step_size, should_diverge, should_turn, expected_doublings",
138139
[
139-
(100000.0, True, True, 1),
140+
(100000.0, True, False, 1),
140141
(0.0000001, False, False, 10),
141142
(1.0, False, True, 1),
142143
],

0 commit comments

Comments
 (0)