Skip to content

Commit 2c5080d

Browse files
committed
Add --num_paths and --states
1 parent cc016e2 commit 2c5080d

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

tps_baseline.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from functools import partial
3-
3+
import traceback
44
import jax
55
import numpy as np
66
import matplotlib.pyplot as plt
@@ -31,7 +31,9 @@
3131

3232
parser = ArgumentParser()
3333
parser.add_argument('--mechanism', type=str, choices=['one-way-shooting', 'two-way-shooting'], required=True)
34+
parser.add_argument('--states', type=str, default='phi-psi', choices=['phi-psi', 'rmsd'])
3435
parser.add_argument('--fixed_length', type=int, default=0)
36+
parser.add_argument('--num_paths', type=int, required=True)
3537
parser.add_argument('--num_steps', type=int, default=10,
3638
help='The number of MD steps taken at once. More takes longer to compile but runs faster in the end.')
3739
parser.add_argument('--resume', action='store_true')
@@ -138,7 +140,12 @@ def step_n(step, _x, _v, n, _key):
138140
mdtraj_topology = md.Topology.from_openmm(init_pdb.topology)
139141
phis_psis = phi_psi_from_mdtraj(mdtraj_topology)
140142

141-
savedir = f"out/baselines/alanine-{args.mechanism}-{args.fixed_length}"
143+
savedir = f"out/baselines/alanine-{args.mechanism}"
144+
if args.fixed_length > 0:
145+
savedir += f'-{args.fixed_length}steps'
146+
if args.states == 'rmsd':
147+
savedir += '-rmsd'
148+
142149
os.makedirs(savedir, exist_ok=True)
143150

144151
# Construct the mass matrix
@@ -253,11 +260,17 @@ def step_langevin_backward(_x, _v, _key):
253260
# step
254261
# )
255262

263+
if args.states == 'rmsd':
264+
state_A = jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2))
265+
state_B = jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2))
266+
elif args.states == 'phi-psi':
267+
state_A = jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(A), radius))
268+
state_B = jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(B), radius))
269+
else:
270+
raise ValueError(f"Unknown states {args.states}")
271+
256272
system = tps2.SecondOrderSystem(
257-
jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(A), radius)),
258-
jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(B), radius)),
259-
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
260-
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
273+
state_A, state_B,
261274
jax.jit(lambda _x, _v, _key: step_n(step_langevin_forward, _x, _v, args.num_steps, _key)),
262275
jax.jit(lambda _x, _v, _key: step_n(step_langevin_backward, _x, _v, args.num_steps, _key)),
263276
jax.jit(lambda key: jnp.sqrt(kbT / mass) * jax.random.normal(key, (1, 66)))
@@ -269,7 +282,8 @@ def step_langevin_backward(_x, _v, _key):
269282

270283
if args.resume:
271284
paths = [[x for x in p.astype(np.float32)] for p in np.load(f'{savedir}/paths.npy', allow_pickle=True)]
272-
velocities = [[v for v in p.astype(np.float32)] for p in np.load(f'{savedir}/velocities.npy', allow_pickle=True)]
285+
velocities = [[v for v in p.astype(np.float32)] for p in
286+
np.load(f'{savedir}/velocities.npy', allow_pickle=True)]
273287
with open(f'{savedir}/stats.json', 'r') as fp:
274288
statistics = json.load(fp)
275289

@@ -295,7 +309,7 @@ def step_langevin_backward(_x, _v, _key):
295309

296310
try:
297311
paths, velocities, statistics = tps2.mcmc_shooting(system, mechanism, initial_trajectory,
298-
100, dt_in_ps, jax.random.PRNGKey(1), warmup=0,
312+
args.num_paths, dt_in_ps, jax.random.PRNGKey(1), warmup=0,
299313
fixed_length=args.fixed_length,
300314
stored=stored)
301315
# paths = tps2.unguided_md(system, B, 1, key)
@@ -308,13 +322,16 @@ def step_langevin_backward(_x, _v, _key):
308322
with open(f'{savedir}/stats.json', 'w') as fp:
309323
json.dump(statistics, fp)
310324
except Exception as e:
311-
print(e)
325+
print(traceback.format_exc())
312326
breakpoint()
313327

314328
print(statistics)
315-
print([len(p) for p in paths])
316-
plt.hist([len(p) for p in paths], bins=jnp.sqrt(len(paths)).astype(int).item())
317-
plt.show()
329+
330+
if args.fixed_length == 0:
331+
print([len(p) for p in paths])
332+
plt.hist([len(p) for p in paths], bins=jnp.sqrt(len(paths)).astype(int).item())
333+
plt.savefig(f'{savedir}/lengths.png', bbox_inches='tight')
334+
plt.show()
318335

319336
path_hist = PeriodicPathHistogram()
320337
for i, path in tqdm(enumerate(paths), desc='Adding paths to histogram', total=len(paths)):

0 commit comments

Comments
 (0)