Skip to content

Commit fb71589

Browse files
committed
Add some TPS code and baselines
1 parent 52c5812 commit fb71589

File tree

4 files changed

+391
-3
lines changed

4 files changed

+391
-3
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
models_variational_gp/
2+
baselines/
23

34
# Created by https://www.toptal.com/developers/gitignore/api/macos,python,pycharm+all
45
# Edit at https://www.toptal.com/developers/gitignore?templates=macos,python,pycharm+all

tps.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import jax
2+
import jax.numpy as jnp
3+
from tqdm import tqdm
4+
5+
# Maximum number of frames a shooting trajectory may grow to when the caller
# does not request a fixed length (i.e. passes fixed_length == 0).
MAX_STEPS = 1_000
6+
7+
8+
class System:
    """Bundle of the callables that define a two-state path-sampling problem.

    Attributes:
        start_state: predicate called with a single frame; truthy when the
            frame belongs to the start state.
        target_state: predicate called with a single frame; truthy when the
            frame belongs to the target state.
        step: callable invoked as ``step(frame, key)`` that returns the next
            frame of the dynamics.
    """

    def __init__(self, start_state, target_state, step):
        # The three attributes are independent; they are stored untouched.
        self.step = step
        self.target_state = target_state
        self.start_state = start_state
13+
14+
15+
def one_way_shooting(system, trajectory, fixed_length, key):
    """Propose a new path by re-simulating one side of ``trajectory``.

    A random interior frame (the shooting point) and a random direction are
    drawn. For the forward direction the frames up to and including the
    shooting point are kept; for the backward direction the frames from the
    shooting point to the end are kept and reversed. The kept part is then
    extended with ``system.step`` until it holds ``fixed_length`` frames
    (``MAX_STEPS`` frames when ``fixed_length`` is 0).

    Args:
        system: object exposing ``start_state(frame)``, ``target_state(frame)``
            and ``step(frame, key)``.
        trajectory: list of frames. It is sliced, never mutated in place.
        fixed_length: required number of frames, or 0 for "no constraint".
        key: JAX PRNG key.

    Returns:
        ``(accepted, path)``. ``accepted`` is True when the new path connects
        the start state and the target state (in either order) and satisfies
        the length constraint; a target-to-start path is returned reversed.
        On failure the (partially) simulated path is returned, including the
        offending frame when a NaN was produced.
    """
    # jax.random.split returns only two keys by default, so three are requested
    # explicitly: shooting index, direction, and the dynamics stream. Reading a
    # third entry from a two-key split would not provide an independent key.
    index_key, direction_key, key = jax.random.split(key, 3)

    # pick a random point along the trajectory (never the first or last frame)
    point_idx = int(jax.random.randint(index_key, (1,), 1, len(trajectory) - 1)[0])
    # pick a random direction, either forward (0) or backward (1)
    direction = int(jax.random.randint(direction_key, (1,), 0, 2)[0])

    print(f'testing path with id={point_idx} and direction={"forward" if direction == 0 else "backward"}')

    if direction == 0:
        trajectory = trajectory[:point_idx + 1]
    else:
        trajectory = trajectory[point_idx:][::-1]

    steps = MAX_STEPS if fixed_length == 0 else fixed_length

    while len(trajectory) < steps:
        key, iter_key = jax.random.split(key)
        point = system.step(trajectory[-1], iter_key)
        trajectory.append(point)

        if jnp.isnan(point).any():
            print('Nan detected!!!')
            return False, trajectory

    # The kept part alone may already exceed a requested fixed length.
    length_ok = fixed_length == 0 or len(trajectory) == fixed_length

    if system.start_state(trajectory[0]) and system.target_state(trajectory[-1]):
        return length_ok, trajectory

    if system.target_state(trajectory[0]) and system.start_state(trajectory[-1]):
        if length_ok:
            # Report the path in start -> target orientation.
            return True, trajectory[::-1]
        return False, trajectory

    return False, trajectory
54+
55+
56+
def two_way_shooting(system, trajectory, fixed_length, key):
    """Propose a new path by simulating in both directions from a random frame.

    A random interior frame of ``trajectory`` is used as the shooting point.
    Frames are first appended at the end with ``system.step`` until a frame
    lies in the start or target state (or the frame budget is used up); then
    frames are prepended at the front under the same stopping rule. The frame
    budget is ``fixed_length`` frames, or ``MAX_STEPS`` when ``fixed_length``
    is 0.

    Args:
        system: object exposing ``start_state(frame)``, ``target_state(frame)``
            and ``step(frame, key)``.
        trajectory: list of frames; it is not modified.
        fixed_length: required number of frames, or 0 for "no constraint".
        key: JAX PRNG key.

    Returns:
        ``(True, new_path)`` when the new path connects start and target state
        (a target-to-start path is returned reversed) and satisfies the length
        constraint; otherwise ``(False, trajectory)`` with the unchanged input.
    """
    index_key, dynamics_key = jax.random.split(key)

    # pick a random point along the trajectory (never the first or last frame)
    point_idx = int(jax.random.randint(index_key, (1,), 1, len(trajectory) - 1)[0])
    point = trajectory[point_idx]

    steps = MAX_STEPS if fixed_length == 0 else fixed_length

    # The second half of this split is intentionally discarded; the split is
    # kept so the sequence of random numbers is unchanged.
    key, _ = jax.random.split(dynamics_key)
    new_trajectory = [point]

    # grow the path at its end until a state is reached or the budget is used
    while len(new_trajectory) < steps:
        key, iter_key = jax.random.split(key)
        point = system.step(new_trajectory[-1], iter_key)
        new_trajectory.append(point)

        if jnp.isnan(point).any():
            print('Nan detected!!!')
            return False, trajectory

        if system.start_state(point) or system.target_state(point):
            break

    # grow the path at its front under the same rules
    while len(new_trajectory) < steps:
        key, iter_key = jax.random.split(key)
        point = system.step(new_trajectory[0], iter_key)
        new_trajectory.insert(0, point)

        if jnp.isnan(point).any():
            # Same diagnostic as in the loop above, so a NaN is never silent.
            print('Nan detected!!!')
            return False, trajectory

        if system.start_state(point) or system.target_state(point):
            break

    # throw away the trajectory if it's not the right length
    if fixed_length != 0 and len(new_trajectory) != fixed_length:
        return False, trajectory

    if system.start_state(new_trajectory[0]) and system.target_state(new_trajectory[-1]):
        return True, new_trajectory

    if system.target_state(new_trajectory[0]) and system.start_state(new_trajectory[-1]):
        # Report the path in start -> target orientation.
        return True, new_trajectory[::-1]

    return False, trajectory
102+
103+
104+
def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_length=0, warmup=50):
    """Build a chain of paths by repeatedly proposing and accepting/rejecting.

    Args:
        system: passed through unchanged to ``proposal``.
        proposal: callable ``proposal(system, trajectory, fixed_length, key)``
            returning ``(found, new_trajectory)``.
        initial_trajectory: first element of the chain; never returned.
        num_paths: number of paths to return.
        key: JAX PRNG key.
        fixed_length: forwarded to ``proposal`` (0 means "no constraint").
        warmup: number of accepted paths after the initial one that are
            discarded before paths are collected.

    Returns:
        The ``num_paths`` paths accepted after the warm-up phase.

    Note:
        The loop only ends once enough proposals were accepted; if ``proposal``
        never reports ``found`` it runs forever.
    """
    # pick an initial trajectory
    trajectories = [initial_trajectory]

    with tqdm(total=num_paths) as pbar:
        while len(trajectories) <= num_paths + warmup:
            key, iter_key, accept_key = jax.random.split(key, 3)
            found, new_trajectory = proposal(system, trajectories[-1], fixed_length, iter_key)
            if not found:
                continue

            # Acceptance probability is the ratio of old to new path length.
            ratio = len(trajectories[-1]) / len(new_trajectory)
            # The first trajectory might have a very unreasonable length, so we skip it
            if len(trajectories) == 1 or jax.random.uniform(accept_key, shape=(1,)) < ratio:
                print('path accepted')
                trajectories.append(new_trajectory)

                # Indices 0..warmup are dropped on return, so only paths stored
                # beyond index ``warmup`` count towards the ``num_paths`` total
                # shown by the progress bar.
                if len(trajectories) > warmup + 1:
                    pbar.update(1)
            else:
                print('Rejected path')

    return trajectories[warmup + 1:]
127+
128+
129+
def unguided_md(system, initial_point, num_paths, key, fixed_length=0):
    """Collect transition paths from one long, unbiased simulation.

    The dynamics ``system.step`` is iterated from ``initial_point``. Frames
    that lie in neither the start nor the target state are accumulated into a
    candidate segment; once a frame inside one of the states is reached, the
    segment is kept if it connects the two states (and has ``fixed_length``
    frames when ``fixed_length`` is non-zero).

    NOTE(review): the nesting below was reconstructed from a listing without
    indentation; in particular the reset of ``current_trajectory`` at the end
    of the ``elif`` branch is assumed to run whenever a state is reached.

    Args:
        system: object exposing ``start_state(frame)``, ``target_state(frame)``
            and ``step(frame, key)``.
        initial_point: first frame; must provide ``.clone()``.
        num_paths: number of transition paths to collect.
        key: JAX PRNG key.
        fixed_length: required number of frames per path, or 0 for any length.

    Returns:
        List of ``num_paths`` paths (lists of frames), each oriented from the
        start state to the target state.
    """
    trajectories = []
    current_frame = initial_point.clone()
    current_trajectory = []

    with tqdm(total=num_paths) as pbar:
        while len(trajectories) < num_paths:
            key, iter_key = jax.random.split(key)
            next_frame = system.step(current_frame, iter_key)

            # A frame is "in transition" when it is in neither state.
            is_transition = not (system.start_state(next_frame) or system.target_state(next_frame))
            if is_transition:
                # Starting a new segment: include the frame we are leaving from.
                if len(current_trajectory) == 0:
                    current_trajectory.append(current_frame)

                # Segment already longer than allowed: drop it and start over.
                if fixed_length != 0 and len(current_trajectory) > fixed_length:
                    current_trajectory = []
                    is_transition = False
                else:
                    current_trajectory.append(next_frame)
            elif len(current_trajectory) > 0:
                # A state was reached while a segment is open: close the segment.
                current_trajectory.append(next_frame)

                if fixed_length == 0 or len(current_trajectory) == fixed_length:
                    if system.start_state(current_trajectory[0]) and system.target_state(current_trajectory[-1]):
                        trajectories.append(current_trajectory)
                        pbar.update(1)
                    elif system.target_state(current_trajectory[0]) and system.start_state(current_trajectory[-1]):
                        # Store target -> start paths reversed.
                        trajectories.append(current_trajectory[::-1])
                        pbar.update(1)
                current_trajectory = []

            current_frame = next_frame

    return trajectories

tps_baseline.py

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
import os
2+
import jax
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
import jax.numpy as jnp
6+
import optax
7+
from flax import linen as nn
8+
from flax.training import train_state
9+
from tqdm import trange
10+
11+
# install openmm (from conda)
12+
import openmm.app as app
13+
import openmm.unit as unit
14+
# install dmff (from source)
15+
from dmff import Hamiltonian, NeighborList
16+
# install mdtraj
17+
import mdtraj as md
18+
19+
import tps
20+
# helper function for plotting in 2D
21+
from utils.PlotPathsAlanine_jax import PlotPathsAlanine
22+
from matplotlib import colors
23+
24+
25+
def human_format(num):
    """Format a number with three significant digits and a magnitude suffix.

    Based on https://stackoverflow.com/a/45846841/4417954

    Values with an absolute value of at least 1 are scaled down in steps of
    1000 and get one of ``'', 'K', 'M', 'B', 'T'``; smaller values are scaled
    up in steps of 1000 and get one of ``'', 'm', 'µ', 'n', 'p', 'f'``.
    Magnitudes outside the suffix tables are expressed with the largest
    available suffix instead of failing. Zero and non-finite values are
    returned without a suffix. The sign does not influence the chosen suffix.

    Args:
        num: the number to format.

    Returns:
        The formatted string, e.g. ``'20K'`` for 20000 or ``'500m'`` for 0.5.
    """
    large_suffixes = ['', 'K', 'M', 'B', 'T']
    small_suffixes = ['', 'm', 'µ', 'n', 'p', 'f']

    def render(value, suffix):
        # Fixed-point text without trailing zeros or a dangling decimal point.
        return '{}{}'.format('{:f}'.format(value).rstrip('0').rstrip('.'), suffix)

    num = float('{:.3g}'.format(num))

    # Zero never reaches 1 by multiplication and inf never drops below 1000 by
    # division, so these values (and NaN) must not enter the scaling loops.
    if num == 0 or num != num or abs(num) == float('inf'):
        return render(num, '')

    magnitude = 0
    if abs(num) >= 1:
        while abs(num) >= 1000 and magnitude < len(large_suffixes) - 1:
            magnitude += 1
            num /= 1000.0
        return render(num, large_suffixes[magnitude])

    while abs(num) < 1 and magnitude < len(small_suffixes) - 1:
        magnitude += 1
        num *= 1000.0
    return render(num, small_suffixes[magnitude])
40+
41+
42+
def interpolate(points, steps):
    """Linearly interpolate through ``points`` using ``steps`` intervals in total.

    The ``steps`` intervals are distributed as evenly as possible over the
    ``len(points) - 1`` segments; the first ``steps % (len(points) - 1)``
    segments receive one extra interval. Way-points shared by two segments are
    emitted only once.

    Args:
        points: sequence of at least two array-likes of compatible shape.
        steps: total number of intervals over all segments.

    Returns:
        A list with ``steps + 1`` entries, each one row of the interpolated
        arrays.

    Raises:
        ValueError: if fewer than two points are given, because there is no
            segment to interpolate along.
    """
    num_segments = len(points) - 1
    if num_segments < 1:
        raise ValueError('interpolate needs at least two points, got {}'.format(len(points)))

    def interpolate_two_points(start, stop, steps):
        # Column vector of interpolation weights, broadcast against the points.
        t = jnp.linspace(0, 1, steps + 1).reshape(steps + 1, 1)
        return jnp.array(start) * (1 - t) + jnp.array(stop) * t

    step_size, remaining = divmod(steps, num_segments)

    interpolation = []
    for i in range(num_segments):
        cur_step_size = step_size + (1 if i < remaining else 0)
        current = interpolate_two_points(points[i], points[i + 1], cur_step_size)
        # Skip the first row of later segments: it repeats the previous end.
        interpolation.extend(current if i == 0 else current[1:])

    return interpolation
58+
59+
60+
def phis_psis(position, mdtraj_topology):
    """Return the phi/psi dihedral angles of the given coordinates.

    Args:
        position: array that can be reshaped to
            ``(-1, mdtraj_topology.n_atoms, 3)``.
        mdtraj_topology: mdtraj topology matching the coordinates.

    Returns:
        A jnp array holding phi in the first and psi in the second position
        of its last axis (the stacked ``[phi, psi]`` array, transposed).
    """
    # One (n_atoms, 3) coordinate block per frame, wrapped in an mdtraj trajectory.
    frames = position.reshape(-1, mdtraj_topology.n_atoms, 3)
    traj = md.Trajectory(frames, mdtraj_topology)

    # Element [1] of each result is used (presumably the angle values — see
    # the mdtraj documentation); phi is evaluated before psi.
    dihedrals = [
        md.compute_phi(traj)[1].squeeze(),
        md.compute_psi(traj)[1].squeeze(),
    ]
    return jnp.array(dihedrals).T
65+
66+
67+
def ramachandran(samples, bins=100, path=None, paths=None):
    """Draw a Ramachandran-style plot into the current matplotlib figure.

    NOTE(review): the nesting was reconstructed from a listing without
    indentation; the axis limits are assumed to be set regardless of whether
    ``samples`` is given.

    Args:
        samples: array indexed as ``samples[:, 0]`` (x axis, labelled phi) and
            ``samples[:, 1]`` (y axis, labelled psi), or None to skip the
            histogram.
        bins: forwarded to ``plt.hist2d``.
        path: optional single path (indexed like ``samples``), drawn in red.
        paths: optional iterable of paths, each drawn in blue.
    """
    if samples is not None:
        # 2D histogram with logarithmic colour scale.
        plt.hist2d(samples[:, 0], samples[:, 1], bins=bins, norm=colors.LogNorm(), rasterized=True)
    plt.xlim(-np.pi, np.pi)
    plt.ylim(-np.pi, np.pi)

    # set ticks
    plt.gca().set_xticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])
    plt.gca().set_xticklabels([r'$-\pi$', r'$-\frac {\pi} {2}$', '0', r'$\frac {\pi} {2}$', r'$\pi$'])

    plt.gca().set_yticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])
    plt.gca().set_yticklabels([r'$-\pi$', r'$-\frac {\pi} {2}$', '0', r'$\frac {\pi} {2}$', r'$\pi$'])

    plt.xlabel(r'$\phi$')
    plt.ylabel(r'$\psi$')

    plt.gca().set_aspect('equal', adjustable='box')

    def draw_path(_path, **kwargs):
        # Euclidean distance between consecutive points of the path.
        dist = jnp.sqrt(np.sum(jnp.diff(_path, axis=0) ** 2, axis=1))
        # Mask points followed by a jump larger than pi (presumably a wrap
        # across the periodic boundary, so the line is interrupted there);
        # the trailing False pads the mask to the path length.
        mask = jnp.hstack([dist > jnp.pi, jnp.array([False])])
        masked_path_x, masked_path_y = np.ma.MaskedArray(_path[:, 0], mask), np.ma.MaskedArray(_path[:, 1], mask)
        plt.plot(masked_path_x, masked_path_y, **kwargs)


    if path is not None:
        draw_path(path, color='red')

    if paths is not None:
        # Note: the loop variable reuses (and overwrites) the ``path`` argument.
        for path in paths:
            draw_path(path, color='blue')
98+
99+
100+
# NOTE(review): T is not referenced anywhere in the visible code — confirm
# whether it is still needed.
T = 2.0
# Integration time step: one microsecond, converted to a plain float in seconds.
dt = 1.0 * unit.microsecond
dt = dt.value_in_unit(unit.second)

# Temperature; the plot title in the script section prints it with a "K" suffix.
temp = 298.15
# The assignment below overrides the 298.15 value above.
temp = 10000 # TODO: remove this
# Thermal energy scale proportional to temp. The numeric factors match the
# mantissas of the Boltzmann constant and Avogadro's number, so this is
# presumably k_B * N_A * temp in kJ/mol — TODO confirm units.
kbT = 1.380649 * 6.02214076 * 1e-3 * temp
107+
108+
# Script entry point: simulates a trajectory with the forward-Euler step defined
# below, stores part of it as a PDB file, shows a phi/psi histogram and sets up
# a tps.System for (still commented-out) shooting experiments.
if __name__ == '__main__':
    # Start and target structures; only the topology of the first one is used.
    init_pdb = app.PDBFile("./files/AD_c7eq.pdb")
    target_pdb = app.PDBFile("./files/AD_c7ax.pdb")
    mdtraj_topology = md.Topology.from_openmm(init_pdb.topology)

    savedir = f"baselines/alanine"
    os.makedirs(savedir, exist_ok=True)

    # Construct the mass matrix
    mass = [a.element.mass.value_in_unit(unit.dalton) for a in init_pdb.topology.atoms()]
    # Repeat every atomic mass three times so there is one entry per flattened
    # coordinate (three coordinates per atom).
    new_mass = []
    for mass_ in mass:
        for _ in range(3):
            new_mass.append(mass_)
    mass = jnp.array(new_mass)
    # Obtain sigma, gamma is by default 1
    sigma = jnp.sqrt(2 * kbT / mass)

    # Initial and target shape [BS, 66]
    # Positions are read in nanometers and flattened to a single row.
    A = jnp.array(init_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)).reshape(1, -1)
    B = jnp.array(target_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)).reshape(1, -1)

    # Initialize the potential energy with amber forcefields
    ff = Hamiltonian('amber14/protein.ff14SB.xml', 'amber14/tip3p.xml')
    potentials = ff.createPotential(init_pdb.topology,
                                    nonbondedMethod=app.NoCutoff,
                                    nonbondedCutoff=1.0 * unit.nanometers,
                                    constraints=None,
                                    ewaldErrorTolerance=0.0005)
    # Create a box used when calling
    # Calling U by U(x, box, pairs, ff.paramset.parameters), x is [22, 3] and output the energy, if it is batched, use vmap
    box = np.array([[50.0, 0.0, 0.0], [0.0, 50.0, 0.0], [0.0, 0.0, 50.0]])
    nbList = NeighborList(box, 4.0, potentials.meta["cov_map"])
    nbList.allocate(init_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
    pairs = nbList.pairs


    def U(_x):
        """Potential energy of one configuration given as 66 flat coordinates."""
        _U = potentials.getPotentialFunc()

        return _U(_x.reshape(22, 3), box, pairs, ff.paramset.parameters)


    #TODO: we can introduce gamma here
    def dUdx_fn(_x):
        """Gradient of U with respect to the coordinates, divided by the masses."""
        return jax.grad(lambda _x: U(_x).sum())(_x) / mass


    # Batch over the leading axis, then compile.
    dUdx_fn = jax.vmap(dUdx_fn)
    dUdx_fn = jax.jit(dUdx_fn)

    @jax.jit
    def step(_x, _key):
        """Perform one step of forward euler"""
        # Drift term from the mass-scaled gradient plus Gaussian noise scaled
        # by sqrt(dt) * sigma.
        return _x - dt * dUdx_fn(_x) + jnp.sqrt(dt) * sigma * jax.random.normal(_key, _x.shape)

    key = jax.random.PRNGKey(1)

    # Unbiased simulation starting from the initial structure.
    trajectory = [A] # or [B]
    _x = trajectory[-1]
    steps = 20_000
    for i in trange(steps):
        key, iter_key = jax.random.split(key)
        _x = step(_x, iter_key)
        trajectory.append(_x)

    # Stack the frames into one row of 66 coordinates per frame.
    trajectory = jnp.array(trajectory).reshape(-1, 66)
    assert not jnp.isnan(trajectory).any()
    trajectory_phi_psi = phis_psis(trajectory, mdtraj_topology)

    # Write frames 10000..10999 to a PDB file. The reference PDB is reloaded
    # for every frame and its coordinates are replaced by the simulated ones.
    trajs = None
    for i in range(10000, 11000):
        traj = md.load_pdb('./files/AD_c7eq.pdb')
        traj.xyz = trajectory[i].reshape(22, 3)
        if trajs is None:
            trajs = traj
        else:
            trajs = trajs.join(traj)
    trajs.save(f'{savedir}/ALDP_forward_euler.pdb')

    plt.title(f"{human_format(steps)} steps @ {temp} K, dt = {human_format(dt)}s")
    ramachandran(trajectory_phi_psi)
    plt.show()


    # TODO: this is work in progress. Get some baselines with tps
    # State predicates: a configuration counts as being in the start (A) or
    # target (B) state when every atom lies within 5e-2 (in the nanometer units
    # of A and B) of the corresponding reference atom.
    system = tps.System(
        jax.jit(
            lambda s: jnp.all(jnp.linalg.norm(A.reshape(-1, 22, 3) - s.reshape(-1, 22, 3), axis=2) <= 5e-2, axis=1)),
        jax.jit(
            lambda s: jnp.all(jnp.linalg.norm(B.reshape(-1, 22, 3) - s.reshape(-1, 22, 3), axis=2) <= 5e-2, axis=1)),
        step
    )

    # initial_trajectory = [t.reshape(1, -1) for t in interpolate([A, B], 100)]

    #
    # for i in range(10):
    #     key, iter_key = jax.random.split(key)
    #
    #     # ramachandran(None, path=phis_psis(jnp.vstack(initial_trajectory), mdtraj_topology))
    #     # plt.show()
    #
    #
    #     ok, trajectory = tps.one_way_shooting(system, initial_trajectory, 0, key)
    #     trajectory = jnp.array(trajectory)
    #     trajectory = phis_psis(trajectory, mdtraj_topology)
    #     print('ok?', ok)
    #
    #     ramachandran(None, path=phis_psis(jnp.vstack(initial_trajectory), mdtraj_topology), paths=[trajectory])
    #     plt.show()
    #

    # paths = tps.mcmc_shooting(system, tps.one_way_shooting, initial_trajectory, 10, key, warmup=0)
    # paths = [jnp.array(p) for p in paths]
    #
    # print(paths)
    # ramachandran(None, path=[phis_psis(p, mdtraj_topology) for p in paths][-1])
    # plt.show()

0 commit comments

Comments
 (0)