Skip to content

Commit 374f50d

Browse files
committed
Add control c handler to second order tps and improve logs
1 parent bc65c38 commit 374f50d

File tree

4 files changed

+31
-25
lines changed

4 files changed

+31
-25
lines changed

evaluate_mueller.py

Whitespace-only changes.

mueller.py

Whitespace-only changes.

tps/second_order.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -137,31 +137,37 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_le
137137
if fixed_length > 0:
138138
statistics['fixed_length'] = fixed_length
139139

140-
with tqdm(total=num_paths + warmup, desc='warming up' if warmup > 0 else '') as pbar:
141-
while len(trajectories) <= num_paths + warmup:
142-
statistics['num_tries'] += 1
143-
if len(trajectories) > warmup:
144-
pbar.set_description('')
145-
146-
key, traj_idx_key, iter_key, accept_key = jax.random.split(key, 4)
147-
traj_idx = jax.random.randint(traj_idx_key, (1,), warmup + 1, len(trajectories))[0]
148-
# during warmup, we want an iterative scheme
149-
traj_idx = traj_idx if traj_idx < len(trajectories) else -1
150-
151-
found, new_trajectory, new_velocities = proposal(system, trajectories[traj_idx], fixed_length, iter_key)
152-
statistics['num_force_evaluations'] += len(new_trajectory) - 1
153-
154-
if not found:
155-
continue
156-
157-
ratio = len(trajectories[-1]) / len(new_trajectory)
158-
# The first trajectory might have a very unreasonable length, so we skip it
159-
if len(trajectories) == 1 or jax.random.uniform(accept_key, shape=(1,)) < ratio:
160-
trajectories.append(new_trajectory)
161-
velocities.append(new_velocities)
162-
pbar.update(1)
163-
else:
164-
statistics['num_metropolis_rejected'] += 1
140+
try:
141+
with tqdm(total=num_paths + warmup, desc='warming up' if warmup > 0 else '') as pbar:
142+
while len(trajectories) <= num_paths + warmup:
143+
statistics['num_tries'] += 1
144+
if len(trajectories) > warmup:
145+
pbar.set_description('')
146+
147+
key, traj_idx_key, iter_key, accept_key = jax.random.split(key, 4)
148+
traj_idx = jax.random.randint(traj_idx_key, (1,), warmup + 1, len(trajectories))[0]
149+
# during warmup, we want an iterative scheme
150+
traj_idx = traj_idx if traj_idx < len(trajectories) else -1
151+
152+
found, new_trajectory, new_velocities = proposal(system, trajectories[traj_idx], fixed_length, iter_key)
153+
statistics['num_force_evaluations'] += len(new_trajectory) - 1
154+
155+
if not found:
156+
continue
157+
158+
ratio = len(trajectories[-1]) / len(new_trajectory)
159+
# The first trajectory might have a very unreasonable length, so we skip it
160+
if len(trajectories) == 1 or jax.random.uniform(accept_key, shape=(1,)) < ratio:
161+
trajectories.append(new_trajectory)
162+
velocities.append(new_velocities)
163+
pbar.update(1)
164+
else:
165+
statistics['num_metropolis_rejected'] += 1
166+
except KeyboardInterrupt:
167+
print('SIGINT received, stopping early')
168+
# Fix in case we stop when adding a trajectory
169+
if len(trajectories) > len(velocities):
170+
velocities.append(new_velocities)
165171

166172
return trajectories[warmup + 1:], velocities[warmup:], statistics
167173

tps_baseline_mueller.py

Whitespace-only changes.

0 commit comments

Comments
 (0)