Skip to content

Commit 019880b

Browse files
committed
Use vmap and jit annotation
1 parent a2a69a9 commit 019880b

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

test_euler_step_openmm.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,12 @@ def U(_x):
5959
return _U(_x.reshape(22, 3), box, pairs, ff.paramset.parameters)
6060

6161

62+
@jax.jit
63+
@jax.vmap
6264
def dUdx_fn_unscaled(_x):
6365
return jax.grad(lambda _x: U(_x).sum())(_x)
6466

6567

66-
dUdx_fn_unscaled = jax.vmap(dUdx_fn_unscaled)
67-
dUdx_fn_unscaled = jax.jit(dUdx_fn_unscaled)
68-
69-
7068
@jax.jit
7169
def dUdx_fn(_x):
7270
return dUdx_fn_unscaled(_x) / mass / gamma

test_langevin_step_openmm.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,12 @@ def U(_x):
6060
return _U(_x.reshape(22, 3), box, pairs, ff.paramset.parameters)
6161

6262

63+
@jax.jit
64+
@jax.vmap
6365
def dUdx_fn_unscaled(_x):
6466
return jax.grad(lambda _x: U(_x).sum())(_x)
6567

6668

67-
dUdx_fn_unscaled = jax.vmap(dUdx_fn_unscaled)
68-
dUdx_fn_unscaled = jax.jit(dUdx_fn_unscaled)
69-
70-
7169
@jax.jit
7270
def dUdx_fn(_x):
7371
return dUdx_fn_unscaled(_x) / mass / gamma
@@ -122,7 +120,8 @@ def step_langevin_units(_x, _v, _key):
122120
# again, we compare the velocities in the same way as we did with the positions
123121
_v_v1 = jax.random.normal(velocity_key, _x.shape) * jnp.sqrt(kbT / mass)
124122

125-
velocity_variance = unit.Quantity(1 / mass, unit=1 / unit.dalton) * unit.BOLTZMANN_CONSTANT_kB * unit.Quantity(temp, unit=unit.kelvin)
123+
velocity_variance = unit.Quantity(1 / mass, unit=1 / unit.dalton) * unit.BOLTZMANN_CONSTANT_kB * unit.Quantity(temp,
124+
unit=unit.kelvin)
126125
# Although velocity+variance is of the unit J / Da = m^2 / s^2, openmm cannot handle this directly and we need to convert it
127126
velocity_variance_in_si = 1 / physical_constants['unified atomic mass unit'][
128127
0] * velocity_variance.value_in_unit(unit.joule / unit.dalton)

0 commit comments

Comments
 (0)