Skip to content

Commit 658d924

Browse files
committed
Move path histogram plotting to end of inference
1 parent 358c7e0 commit 658d924

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

main.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -215,27 +215,25 @@ def main():
215215
# In case we have a second order integration scheme, we remove the velocity for plotting
216216
x_t_det_no_vel = x_t_det[:, :, :system.A.shape[0]]
217217

218+
key, path_key = jax.random.split(key)
219+
x_t_stoch = setup.sample_paths(state_q, x_0, args.dt, args.T, args.BS, path_key)
220+
x_t_stoch_no_vel = x_t_stoch[:, :, :system.A.shape[0]]
221+
np.save(f'{args.save_dir}/stochastic_paths.npy', x_t_stoch_no_vel)
222+
218223
if system.mdtraj_topology:
219224
save_trajectory(system.mdtraj_topology, x_t_det_no_vel[0].reshape(1, -1, 3), f'{args.save_dir}/det_0.pdb')
220225
save_trajectory(system.mdtraj_topology, x_t_det_no_vel[-1].reshape(1, -1, 3), f'{args.save_dir}/det_-1.pdb')
221226

227+
save_trajectory(system.mdtraj_topology, x_t_stoch_no_vel[0].reshape(1, -1, 3), f'{args.save_dir}/stoch_0.pdb')
228+
save_trajectory(system.mdtraj_topology, x_t_stoch_no_vel[-1].reshape(1, -1, 3), f'{args.save_dir}/stoch_-1.pdb')
229+
222230
if system.plot:
223231
plot_energy(system, [x_t_det_no_vel[0], x_t_det_no_vel[-1]], args.log_plots)
224232
show_or_save_fig(args.save_dir, 'path_energy_deterministic', args.extension)
225233

226234
system.plot(title='Deterministic Paths', trajectories=x_t_det_no_vel)
227235
show_or_save_fig(args.save_dir, 'paths_deterministic', args.extension)
228236

229-
key, path_key = jax.random.split(key)
230-
x_t_stoch = setup.sample_paths(state_q, x_0, args.dt, args.T, args.BS, path_key)
231-
x_t_stoch_no_vel = x_t_stoch[:, :, :system.A.shape[0]]
232-
np.save(f'{args.save_dir}/stochastic_paths.npy', x_t_stoch_no_vel)
233-
234-
if system.mdtraj_topology:
235-
save_trajectory(system.mdtraj_topology, x_t_stoch_no_vel[0].reshape(1, -1, 3), f'{args.save_dir}/stoch_0.pdb')
236-
save_trajectory(system.mdtraj_topology, x_t_stoch_no_vel[-1].reshape(1, -1, 3), f'{args.save_dir}/stoch_-1.pdb')
237-
238-
if system.plot:
239237
plot_energy(system, [x_t_stoch_no_vel[0], x_t_stoch_no_vel[-1]], args.log_plots)
240238
show_or_save_fig(args.save_dir, 'path_energy_stochastic', args.extension)
241239

0 commit comments

Comments
 (0)