|
| 1 | +import os |
| 2 | +import jax |
| 3 | +import numpy as np |
| 4 | +import matplotlib.pyplot as plt |
| 5 | +import jax.numpy as jnp |
| 6 | +import optax |
| 7 | +from flax import linen as nn |
| 8 | +from flax.training import train_state |
| 9 | +from tqdm import trange |
| 10 | + |
| 11 | +# install openmm (from conda) |
| 12 | +import openmm.app as app |
| 13 | +import openmm.unit as unit |
| 14 | +# install dmff (from source) |
| 15 | +from dmff import Hamiltonian, NeighborList |
| 16 | +# install mdtraj |
| 17 | +import mdtraj as md |
| 18 | + |
| 19 | +import tps |
| 20 | +# helper function for plotting in 2D |
| 21 | +from utils.PlotPathsAlanine_jax import PlotPathsAlanine |
| 22 | +from matplotlib import colors |
| 23 | + |
| 24 | + |
def human_format(num):
    """Format a number with three significant digits and a magnitude suffix.

    Based on https://stackoverflow.com/a/45846841/4417954

    Values with an absolute value of at least one are scaled down in steps
    of 1000 and get one of the suffixes K, M, B, T. Non-zero values with an
    absolute value below one are scaled up in steps of 1000 and get one of
    the suffixes m, µ, n, p, f.

    Args:
        num: The number to format.

    Returns:
        The formatted string, e.g. ``'20K'`` for ``20000`` or ``'500m'`` for
        ``0.5``. Zero and non-finite values are returned without a suffix.
        Magnitudes beyond the largest/smallest suffix are expressed using
        that last suffix instead of raising an ``IndexError``.
    """
    large_suffixes = ['', 'K', 'M', 'B', 'T']
    small_suffixes = ['', 'm', 'µ', 'n', 'p', 'f']

    # Round to three significant digits before choosing the magnitude.
    num = float('{:.3g}'.format(num))

    # Zero, NaN and infinity can never be scaled into [1, 1000); without this
    # guard the scaling loops below would not terminate for 0 or infinity.
    if num == 0 or num != num or abs(num) == float('inf'):
        return '{:f}'.format(num).rstrip('0').rstrip('.')

    magnitude = 0
    # Decide on the absolute value so that negative numbers are scaled too.
    if abs(num) >= 1:
        suffixes = large_suffixes
        while abs(num) >= 1000 and magnitude < len(suffixes) - 1:
            magnitude += 1
            num /= 1000.0
    else:
        suffixes = small_suffixes
        while abs(num) < 1 and magnitude < len(suffixes) - 1:
            magnitude += 1
            num *= 1000.0

    # Fixed-point formatting with trailing zeros (and a dangling '.') removed.
    mantissa = '{:f}'.format(num).rstrip('0').rstrip('.')
    return '{}{}'.format(mantissa, suffixes[magnitude])
| 40 | + |
| 41 | + |
def interpolate(points, steps):
    """Linearly interpolate through a sequence of points.

    The ``steps`` intervals are distributed over the ``len(points) - 1``
    segments; the first ``steps % (len(points) - 1)`` segments receive one
    extra interval. The point shared by two consecutive segments is emitted
    only once.

    Args:
        points: Sequence of way points (anything accepted by ``jnp.array``).
        steps: Total number of interpolation intervals.

    Returns:
        A list with the interpolated points, one entry per row of the
        stacked per-segment interpolations.
    """
    n_segments = len(points) - 1
    base_steps, extra_steps = divmod(steps, n_segments)

    def blend(start, stop, n_intervals):
        # Column vector of weights so it broadcasts against the coordinates.
        weights = jnp.linspace(0, 1, n_intervals + 1).reshape(n_intervals + 1, 1)
        return jnp.array(start) * (1 - weights) + jnp.array(stop) * weights

    result = []
    for idx in range(n_segments):
        n_intervals = base_steps + 1 if idx < extra_steps else base_steps
        segment = blend(points[idx], points[idx + 1], n_intervals)
        if idx > 0:
            # The first row repeats the end point of the previous segment.
            segment = segment[1:]
        result.extend(segment)

    return result
| 58 | + |
| 59 | + |
def phis_psis(position, mdtraj_topology):
    """Compute the phi and psi dihedral angles of the given positions.

    Args:
        position: Array of coordinates; it is reshaped to
            ``(-1, mdtraj_topology.n_atoms, 3)`` before use.
        mdtraj_topology: mdtraj topology describing the atoms.

    Returns:
        A jax array built from the squeezed phi and psi angles and then
        transposed, i.e. phi first and psi second along the last axis.
    """
    frames = position.reshape(-1, mdtraj_topology.n_atoms, 3)
    traj = md.Trajectory(frames, mdtraj_topology)

    # Index 1 of the mdtraj result holds the angles (index 0 is not used here).
    phi_angles = md.compute_phi(traj)[1]
    psi_angles = md.compute_psi(traj)[1]

    stacked = jnp.array([phi_angles.squeeze(), psi_angles.squeeze()])
    return stacked.T
| 65 | + |
| 66 | + |
def ramachandran(samples, bins=100, path=None, paths=None):
    """Draw a Ramachandran plot (phi on x, psi on y) on the current axes.

    Args:
        samples: Optional array whose columns 0 and 1 are drawn as a
            log-normalised 2D histogram. Pass ``None`` to skip the histogram.
        bins: Number of histogram bins, forwarded to ``plt.hist2d``.
        path: Optional single path (columns 0 and 1 are plotted), drawn red.
        paths: Optional iterable of paths, each drawn in blue.
    """
    tick_positions = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
    tick_labels = [r'$-\pi$', r'$-\frac {\pi} {2}$', '0', r'$\frac {\pi} {2}$', r'$\pi$']

    def draw_path(_path, **kwargs):
        # Euclidean distance between consecutive points of the path.
        jump = jnp.sqrt(np.sum(jnp.diff(_path, axis=0) ** 2, axis=1))
        # Mask every point whose step to its successor is longer than pi
        # (presumably to avoid lines across the periodic boundary); the last
        # point has no successor and is never masked.
        hidden = jnp.hstack([jump > jnp.pi, jnp.array([False])])
        xs = np.ma.MaskedArray(_path[:, 0], hidden)
        ys = np.ma.MaskedArray(_path[:, 1], hidden)
        plt.plot(xs, ys, **kwargs)

    if samples is not None:
        plt.hist2d(samples[:, 0], samples[:, 1], bins=bins, norm=colors.LogNorm(), rasterized=True)
    plt.xlim(-np.pi, np.pi)
    plt.ylim(-np.pi, np.pi)

    # set ticks
    axes = plt.gca()
    axes.set_xticks(tick_positions)
    axes.set_xticklabels(tick_labels)
    axes.set_yticks(tick_positions)
    axes.set_yticklabels(tick_labels)

    plt.xlabel(r'$\phi$')
    plt.ylabel(r'$\psi$')

    axes.set_aspect('equal', adjustable='box')

    if path is not None:
        draw_path(path, color='red')

    if paths is not None:
        for extra_path in paths:
            draw_path(extra_path, color='blue')
| 98 | + |
| 99 | + |
# NOTE(review): T is not referenced in the visible part of this file.
T = 2.0
# Time step: one microsecond, stored as a plain float in seconds.
dt = (1.0 * unit.microsecond).value_in_unit(unit.second)

# Temperature in Kelvin (see the plot title below); the second assignment
# overrides the first one.
temp = 298.15
temp = 10000  # TODO: remove this
# Thermal energy k_B * T; the factors look like the Boltzmann and Avogadro
# constants scaled by 1e-3, i.e. presumably kJ/mol -- TODO confirm units.
kbT = 1.380649 * 6.02214076 * 1e-3 * temp
| 107 | + |
if __name__ == '__main__':
    # Script entry point: simulate alanine dipeptide with a noisy forward
    # Euler integrator, save a slice of the trajectory as PDB and show a
    # Ramachandran plot of the visited (phi, psi) angles.

    # Start and target structures; the topology of the start structure is
    # converted for use with mdtraj.
    init_pdb = app.PDBFile("./files/AD_c7eq.pdb")
    target_pdb = app.PDBFile("./files/AD_c7ax.pdb")
    mdtraj_topology = md.Topology.from_openmm(init_pdb.topology)

    savedir = f"baselines/alanine"
    os.makedirs(savedir, exist_ok=True)

    # Construct the mass matrix
    # Each atom mass (in dalton) is repeated three times, once per Cartesian
    # coordinate, so the vector lines up with the flattened positions.
    mass = [a.element.mass.value_in_unit(unit.dalton) for a in init_pdb.topology.atoms()]
    new_mass = []
    for mass_ in mass:
        for _ in range(3):
            new_mass.append(mass_)
    mass = jnp.array(new_mass)
    # Obtain sigma, gamma is by default 1
    sigma = jnp.sqrt(2 * kbT / mass)

    # Initial and target shape [BS, 66]
    # Positions are converted to nanometers and flattened to a single row.
    A = jnp.array(init_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)).reshape(1, -1)
    B = jnp.array(target_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)).reshape(1, -1)

    # Initialize the potential energy with amber forcefields
    ff = Hamiltonian('amber14/protein.ff14SB.xml', 'amber14/tip3p.xml')
    potentials = ff.createPotential(init_pdb.topology,
                                    nonbondedMethod=app.NoCutoff,
                                    nonbondedCutoff=1.0 * unit.nanometers,
                                    constraints=None,
                                    ewaldErrorTolerance=0.0005)
    # Create a box used when calling
    # Calling U by U(x, box, pairs, ff.paramset.parameters), x is [22, 3] and output the energy, if it is batched, use vmap
    box = np.array([[50.0, 0.0, 0.0], [0.0, 50.0, 0.0], [0.0, 0.0, 50.0]])
    nbList = NeighborList(box, 4.0, potentials.meta["cov_map"])
    # The neighbor list is allocated once from the initial positions and the
    # resulting pairs are reused for every energy evaluation below.
    nbList.allocate(init_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
    pairs = nbList.pairs


    def U(_x):
        """Potential energy of one configuration given as 66 flat coordinates."""
        _U = potentials.getPotentialFunc()

        # The atom count (22) is hard-coded here.
        return _U(_x.reshape(22, 3), box, pairs, ff.paramset.parameters)


    # TODO: we can introduce gamma here
    def dUdx_fn(_x):
        """Gradient of U with respect to the coordinates, divided by the mass."""
        return jax.grad(lambda _x: U(_x).sum())(_x) / mass


    # Map over the leading (batch) axis first, then compile.
    dUdx_fn = jax.vmap(dUdx_fn)
    dUdx_fn = jax.jit(dUdx_fn)

    @jax.jit
    def step(_x, _key):
        """Perform one step of forward euler

        The update is the drift ``-dt * dUdx_fn(_x)`` plus Gaussian noise
        scaled by ``sqrt(dt) * sigma``, drawn with the random key ``_key``.
        """
        return _x - dt * dUdx_fn(_x) + jnp.sqrt(dt) * sigma * jax.random.normal(_key, _x.shape)

    key = jax.random.PRNGKey(1)

    # Simulate starting from A; a fresh sub-key is split off for every step.
    trajectory = [A]  # or [B]
    _x = trajectory[-1]
    steps = 20_000
    for i in trange(steps):
        key, iter_key = jax.random.split(key)
        _x = step(_x, iter_key)
        trajectory.append(_x)

    # Stack all states into rows of 66 coordinates and make sure the
    # integration did not produce NaNs before computing the dihedrals.
    trajectory = jnp.array(trajectory).reshape(-1, 66)
    assert not jnp.isnan(trajectory).any()
    trajectory_phi_psi = phis_psis(trajectory, mdtraj_topology)

    # Save frames 10000..10999 as one multi-frame PDB. The start structure is
    # loaded as a template for each frame and its coordinates are replaced.
    # NOTE(review): the template file is re-read on every iteration and the
    # frames are joined one by one; loading it once may be sufficient.
    trajs = None
    for i in range(10000, 11000):
        traj = md.load_pdb('./files/AD_c7eq.pdb')
        traj.xyz = trajectory[i].reshape(22, 3)
        if trajs is None:
            trajs = traj
        else:
            trajs = trajs.join(traj)
    trajs.save(f'{savedir}/ALDP_forward_euler.pdb')

    plt.title(f"{human_format(steps)} steps @ {temp} K, dt = {human_format(dt)}s")
    ramachandran(trajectory_phi_psi)
    plt.show()


    # TODO: this is work in progress. Get some baselines with tps
    # The two predicates test, per batch entry, whether every atom lies within
    # 5e-2 (in the units of A and B, i.e. nanometers) of its position in A or
    # in B respectively; the third argument is the integrator step.
    # NOTE(review): `system` is only used by the commented-out code below.
    system = tps.System(
        jax.jit(
            lambda s: jnp.all(jnp.linalg.norm(A.reshape(-1, 22, 3) - s.reshape(-1, 22, 3), axis=2) <= 5e-2, axis=1)),
        jax.jit(
            lambda s: jnp.all(jnp.linalg.norm(B.reshape(-1, 22, 3) - s.reshape(-1, 22, 3), axis=2) <= 5e-2, axis=1)),
        step
    )

    # initial_trajectory = [t.reshape(1, -1) for t in interpolate([A, B], 100)]

    #
    # for i in range(10):
    #     key, iter_key = jax.random.split(key)
    #
    #     # ramachandran(None, path=phis_psis(jnp.vstack(initial_trajectory), mdtraj_topology))
    #     # plt.show()
    #
    #
    #     ok, trajectory = tps.one_way_shooting(system, initial_trajectory, 0, key)
    #     trajectory = jnp.array(trajectory)
    #     trajectory = phis_psis(trajectory, mdtraj_topology)
    #     print('ok?', ok)
    #
    #     ramachandran(None, path=phis_psis(jnp.vstack(initial_trajectory), mdtraj_topology), paths=[trajectory])
    #     plt.show()
    #

    # paths = tps.mcmc_shooting(system, tps.one_way_shooting, initial_trajectory, 10, key, warmup=0)
    # paths = [jnp.array(p) for p in paths]
    #
    # print(paths)
    # ramachandran(None, path=[phis_psis(p, mdtraj_topology) for p in paths][-1])
    # plt.show()
0 commit comments