Skip to content

Commit d8bb79e

Browse files
committed
Add kabsch and second order integration
1 parent fb71589 commit d8bb79e

File tree

3 files changed

+149
-20
lines changed

3 files changed

+149
-20
lines changed

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ dependencies:
1616
- dmff @ git+https://github.com/deepmodeling/[email protected]
1717
- matplotlib==3.8.2
1818
- rdkit==2023.3.3
19+
- ParmEd==4.2.2

tps_baseline.py

Lines changed: 103 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from utils.PlotPathsAlanine_jax import PlotPathsAlanine
2222
from matplotlib import colors
2323

24+
from utils.rmsd import kabsch
25+
from scipy.constants import physical_constants
26+
2427

2528
def human_format(num):
2629
"""https://stackoverflow.com/a/45846841/4417954"""
@@ -88,7 +91,6 @@ def draw_path(_path, **kwargs):
8891
masked_path_x, masked_path_y = np.ma.MaskedArray(_path[:, 0], mask), np.ma.MaskedArray(_path[:, 1], mask)
8992
plt.plot(masked_path_x, masked_path_y, **kwargs)
9093

91-
9294
if path is not None:
9395
draw_path(path, color='red')
9496

@@ -98,13 +100,21 @@ def draw_path(_path, **kwargs):
98100

99101

100102
T = 2.0
101-
dt = 1.0 * unit.microsecond
102-
dt = dt.value_in_unit(unit.second)
103+
dt_as_unit = unit.Quantity(value=1.0, unit=unit.femtoseconds)
104+
dt_in_ps = dt_as_unit.value_in_unit(unit.picosecond)
105+
dt = dt_as_unit.value_in_unit(unit.second)
106+
107+
gamma_as_unit = 1.0 / unit.picosecond
108+
# actually gamma is 1/s, but we are working without units and just need the correct scaling
109+
# TODO: try to get rid of this duplicate definition
110+
gamma = 1.0 * unit.picosecond
111+
gamma_in_ps = gamma.value_in_unit(unit.picosecond)
112+
gamma = gamma.value_in_unit(unit.second)
103113

104114
temp = 298.15
105-
temp = 10000 # TODO: remove this
106115
kbT = 1.380649 * 6.02214076 * 1e-3 * temp
107116

117+
108118
if __name__ == '__main__':
109119
init_pdb = app.PDBFile("./files/AD_c7eq.pdb")
110120
target_pdb = app.PDBFile("./files/AD_c7ax.pdb")
@@ -121,7 +131,7 @@ def draw_path(_path, **kwargs):
121131
new_mass.append(mass_)
122132
mass = jnp.array(new_mass)
123133
# Obtain sigma, gamma is by default 1
124-
sigma = jnp.sqrt(2 * kbT / mass)
134+
sigma = jnp.sqrt(2 * kbT / mass / gamma)
125135

126136
# Initial and target shape [BS, 66]
127137
A = jnp.array(init_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)).reshape(1, -1)
@@ -135,44 +145,76 @@ def draw_path(_path, **kwargs):
135145
constraints=None,
136146
ewaldErrorTolerance=0.0005)
137147
# Create a box used when calling
138-
# Calling U by U(x, box, pairs, ff.paramset.parameters), x is [22, 3] and output the energy, if it is batched, use vmap
139148
box = np.array([[50.0, 0.0, 0.0], [0.0, 50.0, 0.0], [0.0, 0.0, 50.0]])
140149
nbList = NeighborList(box, 4.0, potentials.meta["cov_map"])
141150
nbList.allocate(init_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
142151
pairs = nbList.pairs
143152

144153

145154
def U(_x):
155+
"""
156+
Calling U by U(x, box, pairs, ff.paramset.parameters), x is [22, 3] and output the energy, if it is batched, use vmap
157+
"""
146158
_U = potentials.getPotentialFunc()
147159

148160
return _U(_x.reshape(22, 3), box, pairs, ff.paramset.parameters)
149161

150162

151-
#TODO: we can introduce gamma here
152163
def dUdx_fn(_x):
153-
return jax.grad(lambda _x: U(_x).sum())(_x) / mass
164+
return jax.grad(lambda _x: U(_x).sum())(_x) / mass / gamma
154165

155166

156167
dUdx_fn = jax.vmap(dUdx_fn)
157168
dUdx_fn = jax.jit(dUdx_fn)
158169

170+
159171
@jax.jit
160172
def step(_x, _key):
161173
"""Perform one step of forward euler"""
162174
return _x - dt * dUdx_fn(_x) + jnp.sqrt(dt) * sigma * jax.random.normal(_key, _x.shape)
163175

176+
177+
def dUdx_fn_unscaled(_x):
178+
return jax.grad(lambda _x: U(_x).sum())(_x)
179+
180+
dUdx_fn_unscaled = jax.vmap(dUdx_fn_unscaled)
181+
dUdx_fn_unscaled = jax.jit(dUdx_fn_unscaled)
182+
183+
@jax.jit
184+
def step_langevin(_x, _v, _key):
185+
alpha = jnp.exp(-gamma_in_ps * dt_in_ps)
186+
f_scale = (1 - alpha) / gamma_in_ps
187+
new_v_det = alpha * _v + f_scale * -dUdx_fn_unscaled(_x) / mass
188+
new_v = new_v_det + jnp.sqrt(kbT * (1 - alpha ** 2) / mass) * jax.random.normal(_key, _x.shape)
189+
190+
return _x + dt_in_ps * new_v, new_v
191+
164192
key = jax.random.PRNGKey(1)
193+
key, velocity_key = jax.random.split(key)
194+
steps = 1_000_000
165195

166-
trajectory = [A] # or [B]
196+
trajectory = [A]
167197
_x = trajectory[-1]
168-
steps = 20_000
198+
199+
velocity_variance = unit.Quantity(1 / mass, unit=1 / unit.dalton) * unit.BOLTZMANN_CONSTANT_kB * unit.Quantity(temp, unit=unit.kelvin)
200+
# Although velocity+variance is of the unit J / Da = m^2 / s^2, openmm cannot handle this directly and we need to convert it
201+
velocity_variance_in_si = 1 / physical_constants['unified atomic mass unit'][
202+
0] * velocity_variance.value_in_unit(unit.joule / unit.dalton)
203+
# velocity_variance_in_si = unit.Quantity(velocity_variance_in_si, unit.meter / unit.second)
204+
205+
_v = jnp.sqrt(velocity_variance_in_si) * jax.random.normal(velocity_key, _x.shape)
206+
_v = unit.Quantity(_v, unit.meter / unit.second).value_in_unit(unit.nanometer / unit.picosecond)
207+
169208
for i in trange(steps):
170209
key, iter_key = jax.random.split(key)
171-
_x = step(_x, iter_key)
210+
_x, _v = step_langevin(_x, _v, iter_key)
211+
172212
trajectory.append(_x)
173213

174214
trajectory = jnp.array(trajectory).reshape(-1, 66)
175-
assert not jnp.isnan(trajectory).any()
215+
216+
# we only need to check whether the last frame contains nan, is it propagates
217+
assert not jnp.isnan(trajectory[-1]).any()
176218
trajectory_phi_psi = phis_psis(trajectory, mdtraj_topology)
177219

178220
trajs = None
@@ -189,15 +231,56 @@ def step(_x, _key):
189231
ramachandran(trajectory_phi_psi)
190232
plt.show()
191233

192-
193234
# TODO: this is work in progress. Get some baselines with tps
194-
system = tps.System(
195-
jax.jit(
196-
lambda s: jnp.all(jnp.linalg.norm(A.reshape(-1, 22, 3) - s.reshape(-1, 22, 3), axis=2) <= 5e-2, axis=1)),
197-
jax.jit(
198-
lambda s: jnp.all(jnp.linalg.norm(B.reshape(-1, 22, 3) - s.reshape(-1, 22, 3), axis=2) <= 5e-2, axis=1)),
199-
step
200-
)
235+
236+
# l2_system = tps.System(
237+
# jax.jit(
238+
# lambda s: jnp.all(jnp.linalg.norm(A.reshape(-1, 22, 3) - s.reshape(-1, 22, 3), axis=2) <= 5e-2, axis=1)),
239+
# jax.jit(
240+
# lambda s: jnp.all(jnp.linalg.norm(B.reshape(-1, 22, 3) - s.reshape(-1, 22, 3), axis=2) <= 5e-2, axis=1)),
241+
# step
242+
# )
243+
#
244+
# rmsd_system = tps.System(
245+
# jax.jit(lambda s: kabsch(A.reshape(22, 3), s.reshape(22, 3)) < 0.15),
246+
# jax.jit(lambda s: kabsch(B.reshape(22, 3), s.reshape(22, 3)) < 0.15),
247+
# step
248+
# )
249+
#
250+
# # @jax.jit
251+
# def is_within_phi_psi(s, center, radius, period=2 * jnp.pi):
252+
# points = phis_psis(s, mdtraj_topology)
253+
# delta = jnp.abs(center - points)
254+
# delta = jnp.where(delta > period / 2, delta - period, delta)
255+
#
256+
# return jnp.hypot(delta[:, 0], delta[:, 1]) < radius
257+
#
258+
#
259+
# deg = 180.0 / jnp.pi
260+
# # State('A', torch.tensor([-150, 150]) / deg, torch.tensor([20, 45, 65, 80]) / deg),
261+
# # State('B', torch.tensor([-70, 135]) / deg, torch.tensor([20, 45, 65, 75]) / deg),
262+
# # State('C', torch.tensor([-150, -65]) / deg, torch.tensor([20, 45, 60]) / deg),
263+
# # State('D', torch.tensor([-70, -50]) / deg, torch.tensor([20, 45, 60]) / deg),
264+
# # State('E', torch.tensor([50, -100]) / deg, torch.tensor([20, 45, 65, 80]) / deg),
265+
# # State('F', torch.tensor([40, 65]) / deg, torch.tensor([20, 45, 65, 80]) / deg),
266+
#
267+
# phi_psi_system = tps.System(
268+
# lambda s: is_within_phi_psi(s, jnp.array([-150, 150]) / deg, 20 / deg),
269+
# lambda s: is_within_phi_psi(s, jnp.array([50, -100]) / deg, 20 / deg),
270+
# step
271+
# )
272+
#
273+
# # TODO: fix vmap
274+
# filter1 = jax.vmap(phi_psi_system.start_state)(trajectory)
275+
# filter2 = jax.vmap(phi_psi_system.target_state)(trajectory)
276+
#
277+
# plt.title('start')
278+
# ramachandran(trajectory_phi_psi[filter1])
279+
# plt.show()
280+
#
281+
# plt.title('target')
282+
# ramachandran(trajectory_phi_psi[filter2])
283+
# plt.show()
201284

202285
# initial_trajectory = [t.reshape(1, -1) for t in interpolate([A, B], 100)]
203286

utils/rmsd.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import jax.numpy as jnp
2+
import jax
3+
4+
"""
5+
From https://hunterheidenreich.com/posts/kabsch_algorithm/ and adapted
6+
"""
7+
8+
9+
@jax.jit
10+
def kabsch(P, Q):
11+
"""
12+
Computes the optimal rotation and translation to align two sets of points (P -> Q),
13+
and their RMSD.
14+
15+
:param P: A Nx3 matrix of points
16+
:param Q: A Nx3 matrix of points
17+
:return: A tuple containing the optimal rotation matrix, the optimal
18+
translation vector, and the RMSD.
19+
"""
20+
assert P.shape == Q.shape, "Matrix dimensions must match"
21+
22+
# Compute centroids
23+
centroid_P = jnp.mean(P, axis=0)
24+
centroid_Q = jnp.mean(Q, axis=0)
25+
26+
# Optimal translation
27+
t = centroid_Q - centroid_P
28+
29+
# Center the points
30+
p = P - centroid_P
31+
q = Q - centroid_Q
32+
33+
# Compute the covariance matrix
34+
H = jnp.dot(p.T, q)
35+
36+
# SVD
37+
U, S, Vt = jnp.linalg.svd(H)
38+
39+
# Validate right-handed coordinate system
40+
Vt = jnp.where(jnp.linalg.det(jnp.dot(Vt.T, U.T)) < 0.0, -Vt, Vt)
41+
42+
# Optimal rotation
43+
R = jnp.dot(Vt.T, U.T)
44+
45+
return jnp.sqrt(jnp.sum(jnp.square(jnp.dot(p, R.T) - q)) / P.shape[0])

0 commit comments

Comments
 (0)