35
35
# Obtain xi
36
36
xi = jnp .sqrt (2 * kbT / mass / gamma )
37
37
38
-
39
38
# Initialize the potential energy with amber forcefields
40
39
ff = Hamiltonian ('amber14/protein.ff14SB.xml' , 'amber14/tip3p.xml' )
41
40
potentials = ff .createPotential (init_pdb .topology ,
49
48
nbList .allocate (init_pdb .getPositions (asNumpy = True ).value_in_unit (unit .nanometer ))
50
49
pairs = nbList .pairs
51
50
51
+
52
52
@jax .jit
53
53
def U (_x ):
54
54
"""
@@ -58,21 +58,26 @@ def U(_x):
58
58
59
59
return _U (_x .reshape (22 , 3 ), box , pairs , ff .paramset .parameters )
60
60
61
+
61
62
def dUdx_fn_unscaled (_x ):
62
63
return jax .grad (lambda _x : U (_x ).sum ())(_x )
63
64
65
+
64
66
dUdx_fn_unscaled = jax .vmap (dUdx_fn_unscaled )
65
67
dUdx_fn_unscaled = jax .jit (dUdx_fn_unscaled )
66
68
69
+
67
70
@jax .jit
68
71
def dUdx_fn (_x ):
69
72
return dUdx_fn_unscaled (_x ) / mass / gamma
70
73
74
+
71
75
@jax .jit
72
76
def step (_x , _key ):
73
77
"""Perform one step of forward euler"""
74
78
return _x - dt * dUdx_fn (_x ) + jnp .sqrt (dt ) * xi * jax .random .normal (_key , _x .shape )
75
79
80
+
76
81
def step_units (_x , _key ):
77
82
_x = unit .Quantity (_x .reshape (22 , 3 ), unit .nanometer )
78
83
grad = unit .Quantity (value = dUdx_fn_unscaled (_x .value_in_unit (unit .nanometer ).reshape (1 , 66 )).reshape (22 , 3 ),
@@ -95,6 +100,7 @@ def step_units(_x, _key):
95
100
new_x = new_x_det + noise
96
101
return new_x .value_in_unit (unit .nanometer ).reshape (1 , 66 )
97
102
103
+
98
104
key = jax .random .PRNGKey (1 )
99
105
key , velocity_key = jax .random .split (key )
100
106
steps = 100_000
0 commit comments