Skip to content

Commit bc65c38

Browse files
committed
Add function to represent a molecule with internal coordinates
1 parent 1325b4a commit bc65c38

File tree

1 file changed

+383
-0
lines changed

1 file changed

+383
-0
lines changed

utils/internal.py

Lines changed: 383 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,383 @@
1+
import jax
2+
import jax.numpy as jnp
3+
from jax import grad, jit
4+
5+
# File was ported from:
6+
# https://github.com/VincentStimper/boltzmann-generators/blob/2b177fc155f533933489b8fce8d6483ebad250d3/boltzgen/internal.py
7+
8+
9+
def calc_bonds(ind1, ind2, coords):
    """Compute the distance between pairs of particles for every sample.

    Parameters
    ----------
    ind1 : jnp.ndarray
        A n_bond x 3 array holding, per bond, the three column indices of
        particle 1 inside the flattened coordinate vector
    ind2 : jnp.ndarray
        A n_bond x 3 array holding the same for particle 2
    coords : jnp.ndarray
        A n_batch x n_coord array of flattened input coordinates

    Returns
    -------
    jnp.ndarray
        A n_batch x n_bond array of Euclidean distances
    """
    # Fancy indexing with an (n_bond, 3) index array yields
    # (n_batch, n_bond, 3); the norm is taken over the xyz axis.
    displacement = coords[:, ind2] - coords[:, ind1]
    return jnp.linalg.norm(displacement, axis=2)
24+
25+
26+
def calc_angles(ind1, ind2, ind3, coords):
    """Calculate the angle at particle 2 spanned by particles 1 and 3.

    Parameters
    ----------
    ind1 : jnp.ndarray
        A n_angle x 3 array of indices for the coordinates of particle 1
    ind2 : jnp.ndarray
        A n_angle x 3 array of indices for the coordinates of particle 2
        (the vertex of the angle)
    ind3 : jnp.ndarray
        A n_angle x 3 array of indices for the coordinates of particle 3
    coords : jnp.ndarray
        A n_batch x n_coord array of flattened input coordinates

    Returns
    -------
    jnp.ndarray
        A n_batch x n_angle array of angles in radians, in [0, pi]
    """
    b = coords[:, ind1]
    c = coords[:, ind2]
    d = coords[:, ind3]

    # Unit vectors pointing from the vertex to each neighbour.
    bc = b - c
    bc = bc / jnp.linalg.norm(bc, axis=2, keepdims=True)
    cd = d - c
    cd = cd / jnp.linalg.norm(cd, axis=2, keepdims=True)

    cos_angle = jnp.sum(bc * cd, axis=2)
    # Floating-point rounding can push the dot product of two unit vectors
    # marginally outside [-1, 1] for (anti)parallel bonds, where arccos
    # returns NaN; clamp so such geometries map to 0 or pi instead.
    cos_angle = jnp.clip(cos_angle, -1.0, 1.0)
    return jnp.arccos(cos_angle)
37+
38+
39+
def calc_dihedrals(ind1, ind2, ind3, ind4, coords):
    """Compute signed dihedral angles for chains of four particles.

    Parameters
    ----------
    ind1, ind2, ind3, ind4 : jnp.ndarray
        n_dih x 3 arrays of column indices locating particles 1 to 4 in the
        flattened coordinate vector; the torsion is taken about the 2-3 axis
    coords : jnp.ndarray
        A n_batch x n_coord array of flattened input coordinates

    Returns
    -------
    jnp.ndarray
        A n_batch x n_dih array of angles in radians, as produced by
        ``arctan2`` and then negated
    """
    p_a, p_b, p_c, p_d = (coords[:, idx] for idx in (ind1, ind2, ind3, ind4))

    arm_in = p_a - p_b
    axis_vec = p_c - p_b
    axis_vec = axis_vec / jnp.linalg.norm(axis_vec, axis=2, keepdims=True)
    arm_out = p_d - p_c

    # Remove the component along the central axis from both arms so that
    # only their projections onto the plane normal to the axis remain.
    proj_in = arm_in - jnp.sum(arm_in * axis_vec, axis=2, keepdims=True) * axis_vec
    proj_out = arm_out - jnp.sum(arm_out * axis_vec, axis=2, keepdims=True) * axis_vec

    cos_part = jnp.sum(proj_in * proj_out, axis=2)
    sin_part = jnp.sum(jnp.cross(axis_vec, proj_in, axis=2) * proj_out, axis=2)
    return -jnp.arctan2(sin_part, cos_part)
57+
58+
59+
def reconstruct_cart(cart, ref_atoms, bonds, angles, dihs):
    """Place new atoms from internal coordinates relative to known atoms.

    Parameters
    ----------
    cart : jnp.ndarray
        A n_batch x n_placed x 3 array of already known Cartesian positions
    ref_atoms : jnp.ndarray
        A n_new x 3 integer array; each row indexes (along axis 1 of
        ``cart``) the three reference atoms of one new atom
    bonds, angles, dihs : jnp.ndarray
        n_batch x n_new arrays with the bond length, bond angle and dihedral
        angle of each new atom

    Returns
    -------
    jnp.ndarray
        A n_batch x n_new x 3 array with the positions of the new atoms
    """

    def _unit(vec):
        # Scale each xyz vector to length one.
        return vec / jnp.linalg.norm(vec, axis=2, keepdims=True)

    # Positions of the three reference atoms of every atom to be placed.
    ref_a = cart[:, ref_atoms[:, 0], :]
    ref_b = cart[:, ref_atoms[:, 1], :]
    ref_c = cart[:, ref_atoms[:, 2], :]

    # Add a trailing axis so the scalars broadcast against xyz vectors.
    bond_len = jnp.expand_dims(bonds, axis=2)
    bond_ang = jnp.expand_dims(angles, axis=2)
    torsion = jnp.expand_dims(dihs, axis=2)

    to_b = ref_a - ref_b
    to_c = ref_a - ref_c

    # Two unit vectors perpendicular to `to_b`; the dihedral selects a
    # direction in the plane they span.
    normal = _unit(jnp.cross(to_b, to_c, axis=2))
    in_plane = _unit(jnp.cross(to_b, normal, axis=2))

    radial = _unit(normal * jnp.sin(torsion) + in_plane * jnp.cos(torsion))
    radial = radial * bond_len * jnp.sin(bond_ang)

    axial = _unit(to_b) * bond_len * jnp.cos(bond_ang)

    return ref_a + radial - axial
92+
93+
94+
class InternalCoordinateTransform:
    """Convert flattened Cartesian coordinates to normalized internal ones.

    Every atom listed in ``z_indices`` has its three Cartesian columns
    replaced by (bond, angle, dihedral) values measured relative to three
    reference atoms. These values are then shifted and scaled with
    statistics estimated from ``data``. Atoms listed in ``cart_indices``
    keep their Cartesian columns and seed the reconstruction performed by
    ``to_cartesian``.

    Parameters
    ----------
    dims : int
        Length of a flattened coordinate vector; ``dims // 3`` atoms are
        assumed.
    z_indices : iterable
        Pairs ``(atom, (ref1, ref2, ref3))``: ``atom`` is bonded to ``ref1``,
        forms an angle with ``ref1``/``ref2`` and a dihedral with
        ``ref1``/``ref2``/``ref3``.
    cart_indices : sequence of int
        Atoms that remain Cartesian.
    data : jnp.ndarray
        A n_samples x dims array used to estimate means and standard
        deviations.
    ind_circ_dih : sequence of int, optional
        Positions (within the list of dihedrals) whose standard deviation is
        fixed to 1. Defaults to no entries.
    shift_dih : bool
        If true, the mean of every dihedral in ``ind_circ_dih`` is replaced
        by a value derived from the least populated histogram bin of
        ``data``.
    shift_dih_params : dict, optional
        Must contain ``'hist_bins'``. Defaults to ``{'hist_bins': 100}``.
    default_std : dict, optional
        Standard deviations used when ``data`` has a single sample, keyed by
        ``'bond'``, ``'angle'`` and ``'dih'``. Defaults to
        ``{'bond': 0.005, 'angle': 0.15, 'dih': 0.2}``.
    """

    def __init__(self, dims, z_indices=None, cart_indices=None, data=None,
                 ind_circ_dih=None, shift_dih=False,
                 shift_dih_params=None,
                 default_std=None):
        # None sentinels instead of mutable defaults, so instances never
        # share (and cannot accidentally mutate) one list/dict object.
        if ind_circ_dih is None:
            ind_circ_dih = []
        if shift_dih_params is None:
            shift_dih_params = {'hist_bins': 100}
        if default_std is None:
            default_std = {'bond': 0.005, 'angle': 0.15, 'dih': 0.2}

        self.dims = dims
        # Setup indexing.
        self._setup_indices(z_indices, cart_indices)
        self._validate_data(data)
        # Raw (un-normalized) internal coordinates of the training data.
        transformed = self._fwd(data)

        # Normalize one coordinate family at a time. Each std is estimated
        # after the corresponding mean has been removed.
        self.default_std = default_std
        self.ind_circ_dih = ind_circ_dih
        bond_idx = self.bond_indices
        angle_idx = self.angle_indices
        dih_idx = self.dih_indices

        self._setup_mean_bonds(transformed)
        transformed = transformed.at[:, bond_idx].set(
            transformed[:, bond_idx] - self.mean_bonds)
        self._setup_std_bonds(transformed)
        transformed = transformed.at[:, bond_idx].set(
            transformed[:, bond_idx] / self.std_bonds)

        self._setup_mean_angles(transformed)
        transformed = transformed.at[:, angle_idx].set(
            transformed[:, angle_idx] - self.mean_angles)
        self._setup_std_angles(transformed)
        transformed = transformed.at[:, angle_idx].set(
            transformed[:, angle_idx] / self.std_angles)

        self._setup_mean_dih(transformed)
        transformed = transformed.at[:, dih_idx].set(
            transformed[:, dih_idx] - self.mean_dih)
        # Wrap the centred dihedrals before estimating their spread.
        transformed = self._fix_dih(transformed)
        self._setup_std_dih(transformed)
        transformed = transformed.at[:, dih_idx].set(
            transformed[:, dih_idx] / self.std_dih)

        if shift_dih:
            n_bins = shift_dih_params['hist_bins']
            val = jnp.linspace(-jnp.pi, jnp.pi, n_bins)
            for i in self.ind_circ_dih:
                # Undo the normalization to recover the raw angle.
                dih = transformed[:, dih_idx[i]]
                dih = dih * self.std_dih[i] + self.mean_dih[i]
                dih = (dih + jnp.pi) % (2 * jnp.pi) - jnp.pi
                hist = jnp.histogram(dih, bins=n_bins,
                                     range=(-jnp.pi, jnp.pi))[0]
                # New mean: grid value of the emptiest bin, offset by pi.
                self.mean_dih = self.mean_dih.at[i].set(
                    val[jnp.argmin(hist)] + jnp.pi)
                dih = (dih - self.mean_dih[i]) / self.std_dih[i]
                dih = (dih + jnp.pi) % (2 * jnp.pi) - jnp.pi
                transformed = transformed.at[:, dih_idx[i]].set(dih)

    def to_internal(self, x):
        """Return ``x`` with z-matrix atoms as normalized internal coordinates.

        Parameters
        ----------
        x : jnp.ndarray
            A n_batch x dims array of flattened Cartesian coordinates.

        Returns
        -------
        jnp.ndarray
            An array of the same shape. Columns of atoms in ``z_indices``
            hold normalized (bond, angle, dihedral) values; all other columns
            are unchanged.
        """
        trans = self._fwd(x)
        bond_idx = self.bond_indices
        angle_idx = self.angle_indices
        dih_idx = self.dih_indices

        trans = trans.at[:, bond_idx].set(trans[:, bond_idx] - self.mean_bonds)
        trans = trans.at[:, bond_idx].set(trans[:, bond_idx] / self.std_bonds)
        trans = trans.at[:, angle_idx].set(trans[:, angle_idx] - self.mean_angles)
        trans = trans.at[:, angle_idx].set(trans[:, angle_idx] / self.std_angles)
        trans = trans.at[:, dih_idx].set(trans[:, dih_idx] - self.mean_dih)
        # Wrap the centred dihedrals before scaling.
        trans = self._fix_dih(trans)
        trans = trans.at[:, dih_idx].set(trans[:, dih_idx] / self.std_dih)
        return trans

    def _fwd(self, x):
        """Replace the Cartesian columns of z-matrix atoms by raw internals.

        All bonds, angles and dihedrals are computed from the incoming
        Cartesian coordinates before any column is overwritten.
        """
        # Accept array-likes without an ``.at`` property (e.g. NumPy arrays).
        x = jnp.asarray(x)

        # we can do everything in parallel...
        inds1 = self.inds_for_atom[self.rev_z_indices[:, 1]]
        inds2 = self.inds_for_atom[self.rev_z_indices[:, 2]]
        inds3 = self.inds_for_atom[self.rev_z_indices[:, 3]]
        inds4 = self.inds_for_atom[self.rev_z_indices[:, 0]]

        # Calculate the bonds, angles, and torsions for a batch.
        bonds = calc_bonds(inds1, inds4, coords=x)
        angles = calc_angles(inds2, inds1, inds4, coords=x)
        dihedrals = calc_dihedrals(inds3, inds2, inds1, inds4, coords=x)

        # Replace the cartesian coordinates with internal coordinates:
        # the x slot receives the bond, y the angle and z the dihedral.
        x = x.at[:, inds4[:, 0]].set(bonds)
        x = x.at[:, inds4[:, 1]].set(angles)
        x = x.at[:, inds4[:, 2]].set(dihedrals)
        return x

    def to_cartesian(self, x):
        """Rebuild flattened Cartesian coordinates from internal ones.

        Parameters
        ----------
        x : jnp.ndarray
            A n_batch x dims array laid out like the output of
            ``to_internal``.

        Returns
        -------
        jnp.ndarray
            A n_batch x dims array of flattened Cartesian coordinates in the
            original atom order.
        """
        x = jnp.asarray(x)

        # Gather all of the atoms represented as Cartesian coordinates.
        n_batch = x.shape[0]
        cart = x[:, self.init_cart_indices].reshape(n_batch, -1, 3)

        # Loop over all of the blocks, where all of the atoms in each block
        # can be built in parallel because they only depend on atoms that
        # are already Cartesian. `atoms_to_build` lists the `n` atoms
        # that can be built as a batch, where the indexing refers to the
        # original atom order. `ref_atoms` has size n x 3, where the indexing
        # refers to the position in `cart`, rather than the original order.
        for block in self.rev_blocks:
            atoms_to_build = block[:, 0]
            ref_atoms = block[:, 1:]
            stats = self.atom_to_stats[atoms_to_build]

            # Retrieve the bond, angle and dihedral columns and
            # un-normalize them.
            bonds = (
                x[:, 3 * atoms_to_build] * self.std_bonds[stats]
                + self.mean_bonds[stats]
            )
            angles = (
                x[:, 3 * atoms_to_build + 1] * self.std_angles[stats]
                + self.mean_angles[stats]
            )
            dihs = (
                x[:, 3 * atoms_to_build + 2] * self.std_dih[stats]
                + self.mean_dih[stats]
            )

            # Fix the dihedrals to lie in [-pi, pi]. The lower bound is
            # -pi: comparing against +pi would shift every in-range value
            # up by 2*pi only to shift it back, losing precision.
            dihs = jnp.where(dihs < -jnp.pi, dihs + 2 * jnp.pi, dihs)
            dihs = jnp.where(dihs > jnp.pi, dihs - 2 * jnp.pi, dihs)

            # Compute the Cartesian coordinates for the newly placed atoms.
            new_cart = reconstruct_cart(cart, ref_atoms, bonds, angles, dihs)

            # Concatenate the Cartesian coordinates for the newly placed
            # atoms onto the full set of Cartesian coordinates.
            cart = jnp.concatenate([cart, new_cart], axis=1)

        # Permute cart back into the original order and flatten.
        cart = cart[:, self.rev_perm_inv]
        cart = cart.reshape(n_batch, -1)
        return cart

    def _setup_mean_bonds(self, x):
        """Store the per-bond mean over the samples in ``x``."""
        self.mean_bonds = jnp.mean(x[:, self.bond_indices], axis=0)

    def _setup_std_bonds(self, x):
        """Store the per-bond std, or the default for a single sample."""
        if x.shape[0] > 1:
            self.std_bonds = jnp.std(x[:, self.bond_indices], axis=0)
        else:
            self.std_bonds = jnp.ones_like(self.mean_bonds) * self.default_std['bond']

    def _setup_mean_angles(self, x):
        """Store the per-angle mean over the samples in ``x``."""
        self.mean_angles = jnp.mean(x[:, self.angle_indices], axis=0)

    def _setup_std_angles(self, x):
        """Store the per-angle std, or the default for a single sample."""
        if x.shape[0] > 1:
            self.std_angles = jnp.std(x[:, self.angle_indices], axis=0)
        else:
            self.std_angles = jnp.ones_like(self.mean_angles) * self.default_std['angle']

    def _setup_mean_dih(self, x):
        """Store the circular mean of every dihedral.

        Averaging sine and cosine separately keeps angles on either side of
        the +/-pi seam from cancelling each other.
        """
        sin = jnp.mean(jnp.sin(x[:, self.dih_indices]), axis=0)
        cos = jnp.mean(jnp.cos(x[:, self.dih_indices]), axis=0)
        self.mean_dih = jnp.arctan2(sin, cos)

    def _fix_dih(self, x):
        """Return ``x`` with its dihedral columns wrapped into [-pi, pi)."""
        dih = x[:, self.dih_indices]
        dih = (dih + jnp.pi) % (2 * jnp.pi) - jnp.pi
        x = x.at[:, self.dih_indices].set(dih)
        return x

    def _setup_std_dih(self, x):
        """Store the per-dihedral std, or the default for a single sample.

        Entries listed in ``ind_circ_dih`` are set to 1 afterwards.
        """
        if x.shape[0] > 1:
            # Plain indexing is required here: ``x.at[...]`` is an update
            # helper, not an array, and cannot be passed to ``jnp.std``.
            self.std_dih = jnp.std(x[:, self.dih_indices], axis=0)
        else:
            self.std_dih = jnp.ones_like(self.mean_dih) * self.default_std['dih']
        if len(self.ind_circ_dih) > 0:
            self.std_dih = self.std_dih.at[jnp.array(self.ind_circ_dih)].set(1.)

    def _validate_data(self, data):
        """Check that ``data`` is a 2-D array with ``self.dims`` columns.

        Raises
        ------
        ValueError
            If ``data`` is None, not two-dimensional, or has the wrong
            number of columns.
        """
        if data is None:
            raise ValueError(
                "InternalCoordinateTransform must be supplied with training_data."
            )

        if len(data.shape) != 2:
            raise ValueError("training_data must be n_samples x n_dim array")

        n_dim = data.shape[1]

        if n_dim != self.dims:
            raise ValueError(
                f"training_data must have {self.dims} dimensions, not {n_dim}."
            )

    def _setup_indices(self, z_indices, cart_indices):
        """Precompute every index array used by the forward/reverse passes.

        Sets ``inds_for_atom``, ``modified_indices``, ``bond_indices``,
        ``angle_indices``, ``dih_indices``, ``sorted_z_indices``,
        ``rev_z_indices``, ``atom_to_stats``, ``rev_perm``, ``rev_perm_inv``,
        ``init_cart_indices`` and ``rev_blocks``.

        Raises
        ------
        ValueError
            If an atom index lies outside ``[0, dims // 3)``, if an atom is
            listed more than once across ``cart_indices`` and ``z_indices``,
            or if some atom can never be built because one of its reference
            atoms is never placed.
        """
        n_atoms = self.dims // 3
        # Row i holds the flat column indices (3i, 3i + 1, 3i + 2) of atom i.
        self.inds_for_atom = jnp.arange(
            3 * n_atoms, dtype=jnp.int32).reshape(n_atoms, 3)

        # Rows are [atom, ref1, ref2, ref3] as plain Python ints, ordered so
        # that an atom comes after any reference atom that is itself
        # defined in z_indices.
        sorted_z_indices = [
            [int(item[0]), int(item[1][0]), int(item[1][1]), int(item[1][2])]
            for item in topological_sort(z_indices)
        ]
        rev_z_indices = sorted_z_indices[::-1]
        cart_atoms = [int(atom) for atom in cart_indices]
        mod = [row[0] for row in sorted_z_indices]

        for atom in cart_atoms + [a for row in sorted_z_indices for a in row]:
            if not 0 <= atom < n_atoms:
                raise ValueError(
                    f"Atom index {atom} is out of range for {n_atoms} atoms."
                )
        placed = cart_atoms + mod
        if len(set(placed)) != len(placed):
            raise ValueError(
                "Each atom may appear only once across cart_indices and z_indices."
            )

        # Flat columns overwritten by the forward pass, in sorted order;
        # per atom the x/y/z slots hold bond/angle/dihedral respectively.
        modified_indices = [3 * atom + k for atom in mod for k in range(3)]
        self.modified_indices = jnp.array(modified_indices, dtype=jnp.int32)
        self.bond_indices = jnp.array(modified_indices[0::3], dtype=jnp.int32)
        self.angle_indices = jnp.array(modified_indices[1::3], dtype=jnp.int32)
        self.dih_indices = jnp.array(modified_indices[2::3], dtype=jnp.int32)
        # The reshape keeps a well-formed (0, 4) array for an empty z-matrix.
        self.sorted_z_indices = jnp.array(
            sorted_z_indices, dtype=jnp.int32).reshape(-1, 4)
        self.rev_z_indices = jnp.array(
            rev_z_indices, dtype=jnp.int32).reshape(-1, 4)

        #
        # Setup indexing for reverse pass.
        #
        # First, create an array that maps from an atom index into
        # mean_bonds, std_bonds, etc.
        atom_to_stats = [0] * n_atoms
        for stat_pos, atom in enumerate(mod):
            atom_to_stats[atom] = stat_pos
        self.atom_to_stats = jnp.array(atom_to_stats, dtype=jnp.int32)

        # rev_perm maps a build position to the original atom index and
        # rev_perm_inv is its inverse. Both are filled in as atoms are
        # placed and converted to arrays once complete.
        rev_perm = [0] * n_atoms
        rev_perm_inv = [0] * n_atoms

        # Create the list of columns that form our initial Cartesian
        # coordinates.
        self.init_cart_indices = self.inds_for_atom[
            jnp.array(cart_atoms, dtype=jnp.int32)].reshape(-1)

        # The initial Cartesian atoms occupy the first build positions.
        for pos, atom in enumerate(cart_atoms):
            rev_perm[pos] = atom
            rev_perm_inv[atom] = pos

        # Break Z into blocks, where all of the atoms within a block
        # can be built in parallel, because they only depend on
        # atoms that are already Cartesian.
        all_cart = set(cart_atoms)
        # Counted from the list itself rather than from a leaked loop
        # variable, so an empty cart_indices is handled as well.
        current_cart_ind = len(cart_atoms)
        blocks = []
        pending = sorted_z_indices
        while pending:
            deferred = []
            next_cart = set()
            block = []
            for atom1, atom2, atom3, atom4 in pending:
                if atom2 in all_cart and atom3 in all_cart and atom4 in all_cart:
                    # Buildable now; it becomes available as a reference
                    # atom from the next block onwards.
                    next_cart.add(atom1)
                    rev_perm[current_cart_ind] = atom1
                    rev_perm_inv[atom1] = current_cart_ind
                    current_cart_ind += 1
                    # References are stored as positions in the Cartesian
                    # array rather than as original atom indices.
                    block.append([
                        atom1,
                        rev_perm_inv[atom2],
                        rev_perm_inv[atom3],
                        rev_perm_inv[atom4],
                    ])
                else:
                    # Not buildable yet; retry in the next pass.
                    deferred.append([atom1, atom2, atom3, atom4])
            if not block:
                # No progress means some reference atom is never placed;
                # without this check the loop would never terminate.
                raise ValueError(
                    "Cannot build atoms "
                    f"{[row[0] for row in deferred]}: their reference atoms "
                    "are neither in cart_indices nor defined in z_indices."
                )
            pending = deferred
            all_cart |= next_cart
            blocks.append(jnp.array(block, dtype=jnp.int32))

        self.rev_perm = jnp.array(rev_perm, dtype=jnp.int32)
        self.rev_perm_inv = jnp.array(rev_perm_inv, dtype=jnp.int32)
        self.rev_blocks = blocks
363+
364+
365+
def topological_sort(graph_unsorted):
    """Order graph nodes so that each one follows the nodes it depends on.

    Parameters
    ----------
    graph_unsorted : mapping or iterable of (node, edges) pairs
        Anything accepted by ``dict``; ``edges`` lists the nodes that
        ``node`` depends on. Edges naming nodes that are not keys of the
        graph count as already satisfied.

    Returns
    -------
    list
        ``(node, edges)`` tuples in dependency order.

    Raises
    ------
    RuntimeError
        If a full pass over the remaining nodes resolves none of them.
    """
    remaining = dict(graph_unsorted)
    ordered = []

    while remaining:
        progressed = False
        # Iterate over a snapshot because entries are deleted along the way;
        # membership is still tested against the live dict, so a node freed
        # earlier in this pass unblocks its dependants immediately.
        for node, edges in list(remaining.items()):
            if any(edge in remaining for edge in edges):
                continue
            progressed = True
            del remaining[node]
            ordered.append((node, edges))

        if not progressed:
            raise RuntimeError("A cyclic dependency occured.")

    return ordered

0 commit comments

Comments
 (0)