Skip to content

Commit 69752b2

Browse files
committed
Format test_euler_step_openmm.py
1 parent ecec24d commit 69752b2

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

test_euler_step_openmm.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
# Obtain xi
3636
xi = jnp.sqrt(2 * kbT / mass / gamma)
3737

38-
3938
# Initialize the potential energy with amber forcefields
4039
ff = Hamiltonian('amber14/protein.ff14SB.xml', 'amber14/tip3p.xml')
4140
potentials = ff.createPotential(init_pdb.topology,
@@ -49,6 +48,7 @@
4948
nbList.allocate(init_pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
5049
pairs = nbList.pairs
5150

51+
5252
@jax.jit
5353
def U(_x):
5454
"""
@@ -58,21 +58,26 @@ def U(_x):
5858

5959
return _U(_x.reshape(22, 3), box, pairs, ff.paramset.parameters)
6060

61+
6162
def dUdx_fn_unscaled(_x):
6263
return jax.grad(lambda _x: U(_x).sum())(_x)
6364

65+
6466
dUdx_fn_unscaled = jax.vmap(dUdx_fn_unscaled)
6567
dUdx_fn_unscaled = jax.jit(dUdx_fn_unscaled)
6668

69+
6770
@jax.jit
6871
def dUdx_fn(_x):
6972
return dUdx_fn_unscaled(_x) / mass / gamma
7073

74+
7175
@jax.jit
7276
def step(_x, _key):
7377
"""Perform one step of forward euler"""
7478
return _x - dt * dUdx_fn(_x) + jnp.sqrt(dt) * xi * jax.random.normal(_key, _x.shape)
7579

80+
7681
def step_units(_x, _key):
7782
_x = unit.Quantity(_x.reshape(22, 3), unit.nanometer)
7883
grad = unit.Quantity(value=dUdx_fn_unscaled(_x.value_in_unit(unit.nanometer).reshape(1, 66)).reshape(22, 3),
@@ -95,6 +100,7 @@ def step_units(_x, _key):
95100
new_x = new_x_det + noise
96101
return new_x.value_in_unit(unit.nanometer).reshape(1, 66)
97102

103+
98104
key = jax.random.PRNGKey(1)
99105
key, velocity_key = jax.random.split(key)
100106
steps = 100_000

test_langevin_step_openmm.py

Whitespace-only changes.

0 commit comments

Comments
 (0)