@@ -137,31 +137,37 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_le
137
137
if fixed_length > 0 :
138
138
statistics ['fixed_length' ] = fixed_length
139
139
140
- with tqdm (total = num_paths + warmup , desc = 'warming up' if warmup > 0 else '' ) as pbar :
141
- while len (trajectories ) <= num_paths + warmup :
142
- statistics ['num_tries' ] += 1
143
- if len (trajectories ) > warmup :
144
- pbar .set_description ('' )
145
-
146
- key , traj_idx_key , iter_key , accept_key = jax .random .split (key , 4 )
147
- traj_idx = jax .random .randint (traj_idx_key , (1 ,), warmup + 1 , len (trajectories ))[0 ]
148
- # during warmup, we want an iterative scheme
149
- traj_idx = traj_idx if traj_idx < len (trajectories ) else - 1
150
-
151
- found , new_trajectory , new_velocities = proposal (system , trajectories [traj_idx ], fixed_length , iter_key )
152
- statistics ['num_force_evaluations' ] += len (new_trajectory ) - 1
153
-
154
- if not found :
155
- continue
156
-
157
- ratio = len (trajectories [- 1 ]) / len (new_trajectory )
158
- # The first trajectory might have a very unreasonable length, so we skip it
159
- if len (trajectories ) == 1 or jax .random .uniform (accept_key , shape = (1 ,)) < ratio :
160
- trajectories .append (new_trajectory )
161
- velocities .append (new_velocities )
162
- pbar .update (1 )
163
- else :
164
- statistics ['num_metropolis_rejected' ] += 1
140
+ try :
141
+ with tqdm (total = num_paths + warmup , desc = 'warming up' if warmup > 0 else '' ) as pbar :
142
+ while len (trajectories ) <= num_paths + warmup :
143
+ statistics ['num_tries' ] += 1
144
+ if len (trajectories ) > warmup :
145
+ pbar .set_description ('' )
146
+
147
+ key , traj_idx_key , iter_key , accept_key = jax .random .split (key , 4 )
148
+ traj_idx = jax .random .randint (traj_idx_key , (1 ,), warmup + 1 , len (trajectories ))[0 ]
149
+ # during warmup, we want an iterative scheme
150
+ traj_idx = traj_idx if traj_idx < len (trajectories ) else - 1
151
+
152
+ found , new_trajectory , new_velocities = proposal (system , trajectories [traj_idx ], fixed_length , iter_key )
153
+ statistics ['num_force_evaluations' ] += len (new_trajectory ) - 1
154
+
155
+ if not found :
156
+ continue
157
+
158
+ ratio = len (trajectories [- 1 ]) / len (new_trajectory )
159
+ # The first trajectory might have a very unreasonable length, so we skip it
160
+ if len (trajectories ) == 1 or jax .random .uniform (accept_key , shape = (1 ,)) < ratio :
161
+ trajectories .append (new_trajectory )
162
+ velocities .append (new_velocities )
163
+ pbar .update (1 )
164
+ else :
165
+ statistics ['num_metropolis_rejected' ] += 1
166
+ except KeyboardInterrupt :
167
+ print ('SIGINT received, stopping early' )
168
+ # Fix in case we stop when adding a trajectory
169
+ if len (trajectories ) > len (velocities ):
170
+ velocities .append (new_velocities )
165
171
166
172
return trajectories [warmup + 1 :], velocities [warmup :], statistics
167
173
0 commit comments