Skip to content

Commit 356ef0e

Browse files
committed
Update TPS baselines
1 parent 019880b commit 356ef0e

File tree

5 files changed

+208
-141
lines changed

5 files changed

+208
-141
lines changed

tps.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ def one_way_shooting(system, trajectory, fixed_length, key):
2020
# pick a random direction, either forward or backward
2121
direction = jax.random.randint(key[1], (1,), 0, 2)[0]
2222

23-
print(f'testing path with id={point_idx} and direction={"forward" if direction == 0 else "backward"}')
24-
2523
if direction == 0:
2624
trajectory = trajectory[:point_idx + 1]
2725

@@ -37,7 +35,6 @@ def one_way_shooting(system, trajectory, fixed_length, key):
3735
trajectory.append(point)
3836

3937
if jnp.isnan(point).any():
40-
print('Nan detected!!!')
4138
return False, trajectory
4239

4340
if system.start_state(trajectory[0]) and system.target_state(trajectory[-1]):
@@ -71,7 +68,6 @@ def two_way_shooting(system, trajectory, fixed_length, key):
7168
new_trajectory.append(point)
7269

7370
if jnp.isnan(point).any():
74-
print('Nan detected!!!')
7571
return False, trajectory
7672

7773
if system.start_state(point) or system.target_state(point):
@@ -115,13 +111,10 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_le
115111
ratio = len(trajectories[-1]) / len(new_trajectory)
116112
# The first trajectory might have a very unreasonable length, so we skip it
117113
if len(trajectories) == 1 or jax.random.uniform(accept_key, shape=(1,)) < ratio:
118-
print('path accepted')
119114
trajectories.append(new_trajectory)
120115

121116
if len(trajectories) > warmup:
122117
pbar.update(1)
123-
else:
124-
print('Rejected path')
125118

126119
return trajectories[warmup + 1:]
127120

tps_baseline.py

Lines changed: 70 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
from utils.PlotPathsAlanine_jax import PlotPathsAlanine
2222
from matplotlib import colors
2323

24-
from utils.rmsd import kabsch
25-
from scipy.constants import physical_constants
24+
from utils.animation import save_trajectory, to_md_traj
25+
from utils.rmsd import kabsch_align, kabsch_rmsd
2626

2727

2828
def human_format(num):
@@ -61,7 +61,7 @@ def interpolate_two_points(start, stop, steps):
6161

6262

6363
def phis_psis(position, mdtraj_topology):
64-
traj = md.Trajectory(position.reshape(-1, mdtraj_topology.n_atoms, 3), mdtraj_topology)
64+
traj = to_md_traj(mdtraj_topology, position)
6565
phi = md.compute_phi(traj)[1].squeeze()
6666
psi = md.compute_psi(traj)[1].squeeze()
6767
return jnp.array([phi, psi]).T
@@ -99,22 +99,31 @@ def draw_path(_path, **kwargs):
9999
draw_path(path, color='blue')
100100

101101

102-
T = 2.0
103-
dt_as_unit = unit.Quantity(value=1.0, unit=unit.femtoseconds)
102+
dt_as_unit = unit.Quantity(value=1, unit=unit.microsecond)
104103
dt_in_ps = dt_as_unit.value_in_unit(unit.picosecond)
105104
dt = dt_as_unit.value_in_unit(unit.second)
106105

107-
gamma_as_unit = 1.0 / unit.picosecond
106+
gamma_as_unit = 1.0 / unit.second
108107
# actually gamma is 1/s, but we are working without units and just need the correct scaling
109108
# TODO: try to get rid of this duplicate definition
110-
gamma = 1.0 * unit.picosecond
109+
gamma = 1.0 * unit.second
111110
gamma_in_ps = gamma.value_in_unit(unit.picosecond)
112111
gamma = gamma.value_in_unit(unit.second)
113112

114-
temp = 298.15
113+
temp = 300
115114
kbT = 1.380649 * 6.02214076 * 1e-3 * temp
116115

117116

117+
@jax.jit
118+
def is_within(_phis_psis, _center, _radius, _period=2 * jnp.pi):
119+
delta = jnp.abs(_center - _phis_psis)
120+
delta = jnp.where(delta > _period / 2, delta - _period, delta)
121+
122+
return jnp.hypot(delta[:, 0], delta[:, 1]) < _radius
123+
124+
125+
deg = 180.0 / jnp.pi
126+
118127
if __name__ == '__main__':
119128
init_pdb = app.PDBFile("./files/AD_c7eq.pdb")
120129
target_pdb = app.PDBFile("./files/AD_c7ax.pdb")
@@ -130,12 +139,14 @@ def draw_path(_path, **kwargs):
130139
for _ in range(3):
131140
new_mass.append(mass_)
132141
mass = jnp.array(new_mass)
133-
# Obtain sigma, gamma is by default 1
134-
sigma = jnp.sqrt(2 * kbT / mass / gamma)
142+
# Obtain xi, gamma is by default 1
143+
xi = jnp.sqrt(2 * kbT / mass / gamma)
135144

136145
# Initial and target shape [BS, 66]
137-
A = jnp.array(init_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)).reshape(1, -1)
138-
B = jnp.array(target_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)).reshape(1, -1)
146+
A = jnp.array(init_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
147+
B = jnp.array(target_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
148+
A, B = kabsch_align(A, B)
149+
A, B = A.reshape(1, -1), B.reshape(1, -1)
139150

140151
# Initialize the potential energy with amber forcefields
141152
ff = Hamiltonian('amber14/protein.ff14SB.xml', 'amber14/tip3p.xml')
@@ -160,150 +171,82 @@ def U(_x):
160171
return _U(_x.reshape(22, 3), box, pairs, ff.paramset.parameters)
161172

162173

174+
@jax.jit
175+
@jax.vmap
163176
def dUdx_fn(_x):
164177
return jax.grad(lambda _x: U(_x).sum())(_x) / mass / gamma
165178

166179

167-
dUdx_fn = jax.vmap(dUdx_fn)
168-
dUdx_fn = jax.jit(dUdx_fn)
169-
170-
171180
@jax.jit
172181
def step(_x, _key):
173182
"""Perform one step of forward euler"""
174-
return _x - dt * dUdx_fn(_x) + jnp.sqrt(dt) * sigma * jax.random.normal(_key, _x.shape)
175-
176-
177-
def dUdx_fn_unscaled(_x):
178-
return jax.grad(lambda _x: U(_x).sum())(_x)
183+
return _x - dt * dUdx_fn(_x) + jnp.sqrt(dt) * xi * jax.random.normal(_key, _x.shape)
179184

180-
dUdx_fn_unscaled = jax.vmap(dUdx_fn_unscaled)
181-
dUdx_fn_unscaled = jax.jit(dUdx_fn_unscaled)
182-
183-
@jax.jit
184-
def step_langevin(_x, _v, _key):
185-
alpha = jnp.exp(-gamma_in_ps * dt_in_ps)
186-
f_scale = (1 - alpha) / gamma_in_ps
187-
new_v_det = alpha * _v + f_scale * -dUdx_fn_unscaled(_x) / mass
188-
new_v = new_v_det + jnp.sqrt(kbT * (1 - alpha ** 2) / mass) * jax.random.normal(_key, _x.shape)
189-
190-
return _x + dt_in_ps * new_v, new_v
191185

192186
key = jax.random.PRNGKey(1)
193187
key, velocity_key = jax.random.split(key)
194-
steps = 1_000_000
188+
steps = 100_000
195189

196190
trajectory = [A]
197191
_x = trajectory[-1]
198192

199-
velocity_variance = unit.Quantity(1 / mass, unit=1 / unit.dalton) * unit.BOLTZMANN_CONSTANT_kB * unit.Quantity(temp, unit=unit.kelvin)
200-
# Although velocity+variance is of the unit J / Da = m^2 / s^2, openmm cannot handle this directly and we need to convert it
201-
velocity_variance_in_si = 1 / physical_constants['unified atomic mass unit'][
202-
0] * velocity_variance.value_in_unit(unit.joule / unit.dalton)
203-
# velocity_variance_in_si = unit.Quantity(velocity_variance_in_si, unit.meter / unit.second)
204-
205-
_v = jnp.sqrt(velocity_variance_in_si) * jax.random.normal(velocity_key, _x.shape)
206-
_v = unit.Quantity(_v, unit.meter / unit.second).value_in_unit(unit.nanometer / unit.picosecond)
207-
208193
for i in trange(steps):
209194
key, iter_key = jax.random.split(key)
210-
_x, _v = step_langevin(_x, _v, iter_key)
195+
_x = step(_x, iter_key)
211196

212197
trajectory.append(_x)
213198

214199
trajectory = jnp.array(trajectory).reshape(-1, 66)
215200

201+
# save_trajectory(mdtraj_topology, trajectory[-1000:], 'simulation.pdb')
202+
216203
# we only need to check whether the last frame contains nan, is it propagates
217204
assert not jnp.isnan(trajectory[-1]).any()
218205
trajectory_phi_psi = phis_psis(trajectory, mdtraj_topology)
219206

220-
trajs = None
221-
for i in range(10000, 11000):
222-
traj = md.load_pdb('./files/AD_c7eq.pdb')
223-
traj.xyz = trajectory[i].reshape(22, 3)
224-
if trajs is None:
225-
trajs = traj
226-
else:
227-
trajs = trajs.join(traj)
228-
trajs.save(f'{savedir}/ALDP_forward_euler.pdb')
229-
230207
plt.title(f"{human_format(steps)} steps @ {temp} K, dt = {human_format(dt)}s")
231208
ramachandran(trajectory_phi_psi)
232209
plt.show()
233210

234-
# TODO: this is work in progress. Get some baselines with tps
235-
236-
# l2_system = tps.System(
237-
# jax.jit(
238-
# lambda s: jnp.all(jnp.linalg.norm(A.reshape(-1, 22, 3) - s.reshape(-1, 22, 3), axis=2) <= 5e-2, axis=1)),
239-
# jax.jit(
240-
# lambda s: jnp.all(jnp.linalg.norm(B.reshape(-1, 22, 3) - s.reshape(-1, 22, 3), axis=2) <= 5e-2, axis=1)),
211+
# Choose a system, either phi psi, or rmsd
212+
# system = tps.System(
213+
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) < 0.1)),
214+
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) < 0.1)),
241215
# step
242216
# )
243-
#
244-
# rmsd_system = tps.System(
245-
# jax.jit(lambda s: kabsch(A.reshape(22, 3), s.reshape(22, 3)) < 0.15),
246-
# jax.jit(lambda s: kabsch(B.reshape(22, 3), s.reshape(22, 3)) < 0.15),
247-
# step
248-
# )
249-
#
250-
# # @jax.jit
251-
# def is_within_phi_psi(s, center, radius, period=2 * jnp.pi):
252-
# points = phis_psis(s, mdtraj_topology)
253-
# delta = jnp.abs(center - points)
254-
# delta = jnp.where(delta > period / 2, delta - period, delta)
255-
#
256-
# return jnp.hypot(delta[:, 0], delta[:, 1]) < radius
257-
#
258-
#
259-
# deg = 180.0 / jnp.pi
260-
# # State('A', torch.tensor([-150, 150]) / deg, torch.tensor([20, 45, 65, 80]) / deg),
261-
# # State('B', torch.tensor([-70, 135]) / deg, torch.tensor([20, 45, 65, 75]) / deg),
262-
# # State('C', torch.tensor([-150, -65]) / deg, torch.tensor([20, 45, 60]) / deg),
263-
# # State('D', torch.tensor([-70, -50]) / deg, torch.tensor([20, 45, 60]) / deg),
264-
# # State('E', torch.tensor([50, -100]) / deg, torch.tensor([20, 45, 65, 80]) / deg),
265-
# # State('F', torch.tensor([40, 65]) / deg, torch.tensor([20, 45, 65, 80]) / deg),
266-
#
267-
# phi_psi_system = tps.System(
268-
# lambda s: is_within_phi_psi(s, jnp.array([-150, 150]) / deg, 20 / deg),
269-
# lambda s: is_within_phi_psi(s, jnp.array([50, -100]) / deg, 20 / deg),
270-
# step
271-
# )
272-
#
273-
# # TODO: fix vmap
274-
# filter1 = jax.vmap(phi_psi_system.start_state)(trajectory)
275-
# filter2 = jax.vmap(phi_psi_system.target_state)(trajectory)
276-
#
277-
# plt.title('start')
278-
# ramachandran(trajectory_phi_psi[filter1])
279-
# plt.show()
280-
#
281-
# plt.title('target')
282-
# ramachandran(trajectory_phi_psi[filter2])
283-
# plt.show()
284-
285-
# initial_trajectory = [t.reshape(1, -1) for t in interpolate([A, B], 100)]
286-
287-
#
288-
# for i in range(10):
289-
# key, iter_key = jax.random.split(key)
290-
#
291-
# # ramachandran(None, path=phis_psis(jnp.vstack(initial_trajectory), mdtraj_topology))
292-
# # plt.show()
293-
#
294-
#
295-
# ok, trajectory = tps.one_way_shooting(system, initial_trajectory, 0, key)
296-
# trajectory = jnp.array(trajectory)
297-
# trajectory = phis_psis(trajectory, mdtraj_topology)
298-
# print('ok?', ok)
299-
#
300-
# ramachandran(None, path=phis_psis(jnp.vstack(initial_trajectory), mdtraj_topology), paths=[trajectory])
301-
# plt.show()
302-
#
303-
304-
# paths = tps.mcmc_shooting(system, tps.one_way_shooting, initial_trajectory, 10, key, warmup=0)
305-
# paths = [jnp.array(p) for p in paths]
306-
#
307-
# print(paths)
308-
# ramachandran(None, path=[phis_psis(p, mdtraj_topology) for p in paths][-1])
309-
# plt.show()
217+
218+
system = tps.System(
219+
lambda s: is_within(phis_psis(s, mdtraj_topology).reshape(-1, 2), phis_psis(A, mdtraj_topology), 20 / deg),
220+
lambda s: is_within(phis_psis(s, mdtraj_topology).reshape(-1, 2), phis_psis(B, mdtraj_topology), 20 / deg),
221+
step
222+
)
223+
224+
filter1 = system.start_state(trajectory)
225+
filter2 = system.target_state(trajectory)
226+
227+
plt.title('start')
228+
ramachandran(trajectory_phi_psi[filter1])
229+
plt.show()
230+
231+
plt.title('target')
232+
ramachandran(trajectory_phi_psi[filter2])
233+
plt.show()
234+
235+
initial_trajectory = [t.reshape(1, -1) for t in interpolate([A, B], 100)]
236+
save_trajectory(mdtraj_topology, jnp.array(initial_trajectory), f'{savedir}/initial_trajectory.pdb')
237+
238+
paths = tps.mcmc_shooting(system, tps.two_way_shooting, initial_trajectory, 5, key, warmup=2)
239+
paths = [jnp.array(p) for p in paths]
240+
# store paths
241+
np.save(f'{savedir}/paths.npy', np.array(paths, dtype=object), allow_pickle=True)
242+
243+
print([len(p) for p in paths])
244+
plt.hist([len(p) for p in paths], bins=jnp.sqrt(len(paths)).astype(int).item())
245+
plt.show()
246+
247+
plt.title(f"{human_format(len(paths))} steps @ {temp} K, dt = {human_format(dt)}s")
248+
ramachandran(jnp.concatenate([phis_psis(p, mdtraj_topology) for p in paths]),
249+
path=phis_psis(jnp.array(initial_trajectory), mdtraj_topology))
250+
plt.show()
251+
252+
save_trajectory(mdtraj_topology, paths[-1], f'{savedir}/final_trajectory.pdb')

0 commit comments

Comments
 (0)