Skip to content

Commit 7a30b22

Browse files
committed
Add args to TPS MD baseline
1 parent fa175ca commit 7a30b22

File tree

2 files changed

+61
-37
lines changed

2 files changed

+61
-37
lines changed

tps/second_order.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,15 @@ def __init__(self, start_state, target_state, step_forward, step_backward, sampl
1515
self.sample_velocity = sample_velocity
1616

1717

18-
def one_way_shooting(system, trajectory, fixed_length, key):
19-
raise NotImplementedError('Not implemented for multi step systems')
18+
def one_way_shooting(system, trajectory, fixed_length, dt, key):
2019
key = jax.random.split(key)
2120

2221
# pick a random point along the trajectory
2322
point_idx = jax.random.randint(key[0], (1,), 1, len(trajectory) - 1)[0]
2423
# pick a random direction, either forward or backward
2524
direction = jax.random.randint(key[1], (1,), 0, 2)[0]
2625

27-
# TODO: Fix correct dt in ps / pass previous velocities
28-
new_velocities = [(trajectory[point_idx] - trajectory[point_idx - 1]) / 0.001]
26+
new_velocities = [(trajectory[point_idx] - trajectory[point_idx - 1]) / dt]
2927

3028
if direction == 0:
3129
trajectory = trajectory[:point_idx + 1]
@@ -40,30 +38,42 @@ def one_way_shooting(system, trajectory, fixed_length, key):
4038
while len(trajectory) < steps:
4139
key, iter_key = jax.random.split(key)
4240
point, velocity = step_function(trajectory[-1], new_velocities[-1], iter_key)
43-
trajectory.append(point)
44-
new_velocities.append(velocity)
4541

46-
if jnp.isnan(point).any() or jnp.isnan(velocity).any():
47-
return False, trajectory, new_velocities
42+
nan_filter = jnp.isnan(point).any(axis=-1).flatten() | jnp.isnan(velocity).any(axis=-1).flatten()
43+
too_big_filter = (jnp.abs(point) > MAX_ABS_VALUE).any(axis=-1).flatten()
4844

49-
# ensure that our trajectory does not explode
50-
if (jnp.abs(point) > MAX_ABS_VALUE).any():
51-
return False, trajectory, new_velocities
45+
start_state_filter = system.start_state(point)
46+
target_state_filter = system.target_state(point)
5247

53-
if system.start_state(trajectory[0]) and system.target_state(trajectory[-1]):
54-
if fixed_length == 0 or len(trajectory) == fixed_length:
55-
return True, trajectory, new_velocities
56-
return False, trajectory, new_velocities
48+
all_filters_combined = start_state_filter | target_state_filter | nan_filter | too_big_filter
49+
50+
limit = jnp.argmax(all_filters_combined) + 1 if all_filters_combined.any() else len(all_filters_combined)
51+
trajectory.extend(point[:limit])
52+
new_velocities.extend(velocity[:limit])
5753

58-
if system.target_state(trajectory[0]) and system.start_state(trajectory[-1]):
59-
if fixed_length == 0 or len(trajectory) == fixed_length:
60-
return True, trajectory[::-1], new_velocities[::-1]
54+
if (nan_filter | too_big_filter)[:limit].any():
6155
return False, trajectory, new_velocities
6256

57+
if (start_state_filter | target_state_filter)[:limit].any():
58+
break
59+
60+
# throw away the trajectory if it's not the right length
61+
if len(trajectory) > steps:
62+
return False, trajectory[:steps], new_velocities[:steps]
63+
64+
if fixed_length != 0 and len(trajectory) != fixed_length:
65+
return False, trajectory, new_velocities
66+
67+
if system.start_state(trajectory[0]) and system.target_state(trajectory[-1]):
68+
return True, trajectory, new_velocities
69+
70+
if system.target_state(trajectory[0]) and system.start_state(trajectory[-1]):
71+
return True, trajectory[::-1], new_velocities[::-1]
72+
6373
return False, trajectory, new_velocities
6474

6575

66-
def two_way_shooting(system, trajectory, fixed_length, key):
76+
def two_way_shooting(system, trajectory, fixed_length, _dt, key):
6777
key = jax.random.split(key)
6878

6979
# pick a random point along the trajectory
@@ -139,7 +149,7 @@ def two_way_shooting(system, trajectory, fixed_length, key):
139149
return False, new_trajectory, new_velocities
140150

141151

142-
def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_length=0, warmup=50):
152+
def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixed_length=0, warmup=50):
143153
# pick an initial trajectory
144154
trajectories = [initial_trajectory]
145155
velocities = []
@@ -162,12 +172,12 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_le
162172
if len(trajectories) > warmup:
163173
pbar.set_description('')
164174

165-
key, traj_idx_key, iter_key, accept_key = jax.random.split(key, 4)
175+
key, traj_idx_key, ikey, accept_key = jax.random.split(key, 4)
166176
traj_idx = jax.random.randint(traj_idx_key, (1,), warmup + 1, len(trajectories))[0]
167177
# during warmup, we want an iterative scheme
168178
traj_idx = traj_idx if traj_idx < len(trajectories) else -1
169179

170-
found, new_trajectory, new_velocities = proposal(system, trajectories[traj_idx], fixed_length, iter_key)
180+
found, new_trajectory, new_velocities = proposal(system, trajectories[traj_idx], fixed_length, dt, ikey)
171181
statistics['num_force_evaluations'] += len(new_trajectory) - 1
172182

173183
if not found:

tps_baseline.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@
2727
from utils.animation import save_trajectory, to_md_traj
2828
from utils.rmsd import kabsch_align, kabsch_rmsd
2929

30+
from argparse import ArgumentParser
31+
32+
parser = ArgumentParser()
33+
parser.add_argument('--mechanism', type=str, choices=['one-way-shooting', 'two-way-shooting'], required=True)
34+
parser.add_argument('--fixed_length', type=int, default=0)
35+
parser.add_argument('--num_steps', type=int, default=10,
36+
help='The number of MD steps taken at once. More takes longer to compile but runs faster in the end.')
37+
3038

3139
def human_format(num):
3240
"""https://stackoverflow.com/a/45846841/4417954"""
@@ -121,12 +129,14 @@ def step_n(step, _x, _v, n, _key):
121129

122130

123131
if __name__ == '__main__':
132+
args = parser.parse_args()
133+
124134
init_pdb = app.PDBFile("./files/AD_A.pdb")
125135
target_pdb = app.PDBFile("./files/AD_B.pdb")
126136
mdtraj_topology = md.Topology.from_openmm(init_pdb.topology)
127137
phis_psis = phi_psi_from_mdtraj(mdtraj_topology)
128138

129-
savedir = f"out/baselines/alanine"
139+
savedir = f"out/baselines/alanine-{args.mechanism}"
130140
os.makedirs(savedir, exist_ok=True)
131141

132142
# Construct the mass matrix
@@ -171,13 +181,8 @@ def U(_x):
171181

172182
@jax.jit
173183
@jax.vmap
174-
def dUdx_fn_unscaled(_x):
175-
return jax.grad(lambda _x: U(_x).sum())(_x)
176-
177-
178-
@jax.jit
179184
def dUdx_fn(_x):
180-
return dUdx_fn_unscaled(_x) / mass / gamma
185+
return jax.grad(lambda _x: U(_x).sum())(_x) / mass / gamma_in_ps
181186

182187

183188
@jax.jit
@@ -188,10 +193,10 @@ def step(_x, _key):
188193

189194
@jax.jit
190195
def step_langevin_forward(_x, _v, _key):
191-
"""Perform one step of forward langevin"""
196+
"""Perform one step of forward langevin as implemented in openmm"""
192197
alpha = jnp.exp(-gamma_in_ps * dt_in_ps)
193198
f_scale = (1 - alpha) / gamma_in_ps
194-
new_v_det = alpha * _v + f_scale * -dUdx_fn_unscaled(_x) / mass
199+
new_v_det = alpha * _v + f_scale * -dUdx_fn(_x)
195200
new_v = new_v_det + jnp.sqrt(kbT * (1 - alpha ** 2) / mass) * jax.random.normal(_key, _x.shape)
196201

197202
return _x + dt_in_ps * new_v, new_v
@@ -201,7 +206,7 @@ def step_langevin_forward(_x, _v, _key):
201206
def step_langevin_log_density(_x, _v, _new_x, _new_v):
202207
alpha = jnp.exp(-gamma_in_ps * dt_in_ps)
203208
f_scale = (1 - alpha) / gamma_in_ps
204-
new_v_det = alpha * _v + f_scale * -dUdx_fn_unscaled(_x) / mass
209+
new_v_det = alpha * _v + f_scale * -dUdx_fn(_x)
205210
new_v_rand = new_v_det - _new_v
206211

207212
return jax.scipy.stats.norm.logpdf(new_v_rand, 0, jnp.sqrt(kbT * (1 - alpha ** 2) / mass)).sum()
@@ -225,7 +230,7 @@ def step_langevin_backward(_x, _v, _key):
225230
alpha = jnp.exp(-gamma_in_ps * dt_in_ps)
226231
f_scale = (1 - alpha) / gamma_in_ps
227232
prev_x = _x - dt_in_ps * _v
228-
prev_v = 1 / alpha * (_v + f_scale * dUdx_fn_unscaled(prev_x) / mass - jnp.sqrt(
233+
prev_v = 1 / alpha * (_v + f_scale * dUdx_fn(prev_x) - jnp.sqrt(
229234
kbT * (1 - alpha ** 2) / mass) * jax.random.normal(_key, _x.shape))
230235

231236
return prev_x, prev_v
@@ -281,8 +286,8 @@ def step_langevin_backward(_x, _v, _key):
281286
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
282287
# step_langevin_forward,
283288
# step_langevin_backward,
284-
jax.jit(lambda _x, _v, _key: step_n(step_langevin_forward, _x, _v, 40, _key)),
285-
jax.jit(lambda _x, _v, _key: step_n(step_langevin_backward, _x, _v, 40, _key)),
289+
jax.jit(lambda _x, _v, _key: step_n(step_langevin_forward, _x, _v, args.num_steps, _key)),
290+
jax.jit(lambda _x, _v, _key: step_n(step_langevin_backward, _x, _v, args.num_steps, _key)),
286291
jax.jit(lambda key: jnp.sqrt(kbT / mass) * jax.random.normal(key, (1, 66)))
287292
)
288293

@@ -311,9 +316,17 @@ def step_langevin_backward(_x, _v, _key):
311316
with open(f'{savedir}/stats.json', 'r') as fp:
312317
statistics = json.load(fp)
313318
else:
319+
if args.mechanism == 'one-way-shooting':
320+
mechanism = tps2.one_way_shooting
321+
elif args.mechanism == 'two-way-shooting':
322+
mechanism = tps2.two_way_shooting
323+
else:
324+
raise ValueError(f"Unknown mechanism {args.mechanism}")
325+
314326
try:
315-
paths, velocities, statistics = tps2.mcmc_shooting(system, tps2.two_way_shooting, initial_trajectory,
316-
100, jax.random.PRNGKey(1), warmup=0, fixed_length=1000)
327+
paths, velocities, statistics = tps2.mcmc_shooting(system, mechanism, initial_trajectory,
328+
100, dt_in_ps, jax.random.PRNGKey(1), warmup=0,
329+
fixed_length=args.fixed_length)
317330
# paths = tps2.unguided_md(system, B, 1, key)
318331
paths = [jnp.array(p) for p in paths]
319332
velocities = [jnp.array(p) for p in velocities]
@@ -324,6 +337,7 @@ def step_langevin_backward(_x, _v, _key):
324337
with open(f'{savedir}/stats.json', 'w') as fp:
325338
json.dump(statistics, fp)
326339
except Exception as e:
340+
print(e)
327341
breakpoint()
328342

329343
print(statistics)

0 commit comments

Comments
 (0)