Skip to content

Commit 4d8d642

Browse files
committed
Implement dihedral angles in jax
1 parent 9d95064 commit 4d8d642

File tree

3 files changed

+88
-27
lines changed

3 files changed

+88
-27
lines changed

tps_baseline.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
# helper function for plotting in 2D
2020
from matplotlib import colors
2121

22+
from utils.angles import phi_psi_from_mdtraj
2223
from utils.animation import save_trajectory, to_md_traj
2324
from utils.rmsd import kabsch_align, kabsch_rmsd
2425

@@ -58,13 +59,6 @@ def interpolate_two_points(start, stop, steps):
5859
return interpolation
5960

6061

61-
def phis_psis(position, mdtraj_topology):
62-
traj = to_md_traj(mdtraj_topology, position)
63-
phi = md.compute_phi(traj)[1].squeeze()
64-
psi = md.compute_psi(traj)[1].squeeze()
65-
return jnp.array([phi, psi]).T
66-
67-
6862
def ramachandran(samples, bins=100, path=None, paths=None, states=None, alpha=1.0):
6963
if samples is not None:
7064
plt.hist2d(samples[:, 0], samples[:, 1], bins=bins, norm=colors.LogNorm(), rasterized=True)
@@ -97,7 +91,8 @@ def draw_path(_path, **kwargs):
9791
draw_path(path, color='blue')
9892

9993
for state in (states if states is not None else []):
100-
c = plt.Circle(state['center'], radius=state['radius'], edgecolor='gray', facecolor='white', ls='--', lw=0.7, alpha=alpha)
94+
c = plt.Circle(state['center'], radius=state['radius'], edgecolor='gray', facecolor='white', ls='--', lw=0.7,
95+
alpha=alpha)
10196
plt.gca().add_patch(c)
10297
plt.gca().annotate(state['name'], xy=state['center'], ha="center", va="center")
10398

@@ -131,6 +126,7 @@ def is_within(_phis_psis, _center, _radius, _period=2 * jnp.pi):
131126
init_pdb = app.PDBFile("./files/AD_A.pdb")
132127
target_pdb = app.PDBFile("./files/AD_B.pdb")
133128
mdtraj_topology = md.Topology.from_openmm(init_pdb.topology)
129+
phis_psis = phi_psi_from_mdtraj(mdtraj_topology)
134130

135131
savedir = f"baselines/alanine"
136132
os.makedirs(savedir, exist_ok=True)
@@ -151,7 +147,6 @@ def is_within(_phis_psis, _center, _radius, _period=2 * jnp.pi):
151147
A, B = kabsch_align(A, B)
152148
A, B = A.reshape(1, -1), B.reshape(1, -1)
153149

154-
155150
# Initialize the potential energy with amber forcefields
156151
ff = Hamiltonian('amber14/protein.ff14SB.xml', 'amber14/tip3p.xml')
157152
potentials = ff.createPotential(init_pdb.topology,
@@ -236,12 +231,12 @@ def step_langevin_backward(_x, _v, _key):
236231

237232
# we only need to check whether the last frame contains nan, is it propagates
238233
assert not jnp.isnan(trajectory[-1]).any()
239-
trajectory_phi_psi = phis_psis(trajectory, mdtraj_topology)
234+
trajectory_phi_psi = phis_psis(trajectory)
240235

241236
plt.title(f"{human_format(steps)} steps @ {temp} K, dt = {human_format(dt)}s")
242237
ramachandran(trajectory_phi_psi)
243-
plt.scatter(phis_psis(A, mdtraj_topology)[0], phis_psis(A, mdtraj_topology)[1], color='red', marker='*')
244-
plt.scatter(phis_psis(B, mdtraj_topology)[0], phis_psis(B, mdtraj_topology)[1], color='green', marker='*')
238+
plt.scatter(phis_psis(A)[0], phis_psis(A)[1], color='red', marker='*')
239+
plt.scatter(phis_psis(B)[0], phis_psis(B)[1], color='green', marker='*')
245240
plt.show()
246241

247242
# Choose a system, either phi psi, or rmsd
@@ -254,23 +249,23 @@ def step_langevin_backward(_x, _v, _key):
254249
radius = 20 / deg
255250

256251
system = tps1.FirstOrderSystem(
257-
lambda s: is_within(phis_psis(s, mdtraj_topology).reshape(-1, 2), phis_psis(A, mdtraj_topology), radius),
258-
lambda s: is_within(phis_psis(s, mdtraj_topology).reshape(-1, 2), phis_psis(B, mdtraj_topology), radius),
252+
lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(A), radius),
253+
lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(B), radius),
259254
step
260255
)
261256

262257
system = tps2.SecondOrderSystem(
263-
# lambda s: is_within(phis_psis(s, mdtraj_topology).reshape(-1, 2), phis_psis(A, mdtraj_topology), radius),
264-
# lambda s: is_within(phis_psis(s, mdtraj_topology).reshape(-1, 2), phis_psis(B, mdtraj_topology), radius),
265-
jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
266-
jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
258+
jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(A), radius)),
259+
jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(B), radius)),
260+
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
261+
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
267262
step_langevin_forward,
268263
step_langevin_backward,
269264
jax.jit(lambda key: jnp.sqrt(kbT / mass) * jax.random.normal(key, (1, 66)))
270265
)
271266

272-
print("A", phis_psis(A, mdtraj_topology))
273-
print("B", phis_psis(B, mdtraj_topology))
267+
print("A", phis_psis(A))
268+
print("B", phis_psis(B))
274269

275270
filter1 = system.start_state(trajectory)
276271
filter2 = system.target_state(trajectory)
@@ -291,7 +286,8 @@ def step_langevin_backward(_x, _v, _key):
291286
if load:
292287
paths = np.load(f'{savedir}/paths.npy', allow_pickle=True)
293288
else:
294-
paths = tps2.mcmc_shooting(system, tps2.two_way_shooting, initial_trajectory, 5, jax.random.PRNGKey(1), warmup=0)
289+
paths = tps2.mcmc_shooting(system, tps2.two_way_shooting, initial_trajectory, 100, jax.random.PRNGKey(1),
290+
warmup=10)
295291
# paths = tps2.unguided_md(system, B, 1, key)
296292
paths = [jnp.array(p) for p in paths]
297293
# store paths
@@ -303,16 +299,17 @@ def step_langevin_backward(_x, _v, _key):
303299

304300
path_hist = PeriodicPathHistogram()
305301
for i, path in tqdm(enumerate(paths)):
306-
path_hist.add_path(np.array(phis_psis(path, mdtraj_topology)))
302+
path_hist.add_path(np.array(phis_psis(path)))
307303

308304
plt.title(f"{human_format(len(paths))} paths @ {temp} K, dt = {human_format(dt)}s")
309305
path_hist.plot(cmin=0.001)
310306
ramachandran(None, states=[
311-
{'name': 'A', 'center': phis_psis(A, mdtraj_topology), 'radius': radius},
312-
{'name': 'B', 'center': phis_psis(B, mdtraj_topology), 'radius': radius},
307+
{'name': 'A', 'center': phis_psis(A).squeeze(), 'radius': radius},
308+
{'name': 'B', 'center': phis_psis(B).squeeze(), 'radius': radius},
313309
], alpha=0.7)
314310
plt.savefig(f'{savedir}/paths.png', bbox_inches='tight')
315311
plt.show()
316312

317313
for i, path in tqdm(enumerate(paths)):
318-
save_trajectory(mdtraj_topology, jnp.array([kabsch_align(p.reshape(-1, 3), B.reshape(-1, 3))[0] for p in path]), f'{savedir}/trajectory_{i}.pdb')
314+
save_trajectory(mdtraj_topology, jnp.array([kabsch_align(p.reshape(-1, 3), B.reshape(-1, 3))[0] for p in path]),
315+
f'{savedir}/trajectory_{i}.pdb')

utils/PlotPathsAlanine_jax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ def PlotPathsAlanine(paths, target, filename):
104104
plt.xlim([-np.pi, np.pi])
105105
plt.ylim([-np.pi, np.pi])
106106

107-
angle_2 = [1, 6, 8, 14]
108-
angle_1 = [6, 8, 14, 16]
107+
angle_2 = [1, 6, 8, 14] # phi
108+
angle_1 = [6, 8, 14, 16] # psi
109109

110110
potential = AlaninePotential()
111111
xs = np.arange(-np.pi, np.pi + .1, .1)

utils/angles.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import jax.numpy as jnp
2+
import jax
3+
from mdtraj.geometry import indices_phi, indices_psi
4+
5+
@jax.jit
6+
def dihedral(p):
7+
"""http://stackoverflow.com/q/20305272/1128289"""
8+
b = p[:-1] - p[1:]
9+
b = b.at[0].set(-b[0])
10+
v = jnp.array(
11+
[v - (v.dot(b[1]) / b[1].dot(b[1])) * b[1] for v in [b[0], b[2]]])
12+
# Normalize vectors
13+
v /= jnp.sqrt(jnp.einsum('...i,...i', v, v)).reshape(-1, 1)
14+
b1 = b[1] / jnp.linalg.norm(b[1])
15+
x = jnp.dot(v[0], v[1])
16+
m = jnp.cross(v[0], b1)
17+
y = jnp.dot(m, v[1])
18+
return jnp.arctan2(y, x)
19+
20+
21+
def phi_psi_from_mdtraj(mdtraj_topology):
22+
angles_phi = indices_phi(mdtraj_topology)[0]
23+
angles_psi = indices_psi(mdtraj_topology)[0]
24+
25+
assert len(angles_phi) == len(angles_psi) == 4
26+
27+
@jax.jit
28+
@jax.vmap
29+
def phi_psi(p):
30+
p = p.reshape(mdtraj_topology.n_atoms, 3)
31+
phi = dihedral(p[angles_phi, :])
32+
psi = dihedral(p[angles_psi, :])
33+
34+
return jnp.array([phi, psi])
35+
36+
return phi_psi
37+
38+
39+
if __name__ == '__main__':
40+
import openmm.app as app
41+
import openmm.unit as unit
42+
import mdtraj as md
43+
from utils.animation import to_md_traj
44+
45+
init_pdb = app.PDBFile("../files/AD_A.pdb")
46+
target_pdb = app.PDBFile("../files/AD_B.pdb")
47+
mdtraj_topology = md.Topology.from_openmm(init_pdb.topology)
48+
49+
A = jnp.array(init_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
50+
B = jnp.array(target_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
51+
52+
phi_psi = phi_psi_from_mdtraj(mdtraj_topology)
53+
print(phi_psi(A.reshape(1, 22, 3)))
54+
print(phi_psi(B.reshape(1, 22, 3)))
55+
56+
traj = to_md_traj(mdtraj_topology, A)
57+
phi = md.compute_phi(traj)[1].squeeze()
58+
psi = md.compute_psi(traj)[1].squeeze()
59+
print(phi, psi)
60+
61+
traj = to_md_traj(mdtraj_topology, B)
62+
phi = md.compute_phi(traj)[1].squeeze()
63+
psi = md.compute_psi(traj)[1].squeeze()
64+
print(phi, psi)

0 commit comments

Comments
 (0)