@@ -279,10 +279,10 @@ def step_langevin_backward(_x, _v, _key):
279
279
)
280
280
281
281
system = tps2 .SecondOrderSystem (
282
- # jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(A), radius)),
283
- # jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(B), radius)),
284
- jax .jit (jax .vmap (lambda s : kabsch_rmsd (A .reshape (22 , 3 ), s .reshape (22 , 3 )) <= 7.5e-2 )),
285
- jax .jit (jax .vmap (lambda s : kabsch_rmsd (B .reshape (22 , 3 ), s .reshape (22 , 3 )) <= 7.5e-2 )),
282
+ jax .jit (lambda s : is_within (phis_psis (s ).reshape (- 1 , 2 ), phis_psis (A ), radius )),
283
+ jax .jit (lambda s : is_within (phis_psis (s ).reshape (- 1 , 2 ), phis_psis (B ), radius )),
284
+ # jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
285
+ # jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
286
286
step_langevin_forward ,
287
287
step_langevin_backward ,
288
288
jax .jit (lambda key : jnp .sqrt (kbT / mass ) * jax .random .normal (key , (1 , 66 )))
@@ -306,7 +306,7 @@ def step_langevin_backward(_x, _v, _key):
306
306
initial_trajectory = [p for p in initial_trajectory ]
307
307
save_trajectory (mdtraj_topology , jnp .array (initial_trajectory ), f'{ savedir } /initial_trajectory.pdb' )
308
308
309
- load = True
309
+ load = False
310
310
if load :
311
311
paths = np .load (f'{ savedir } /paths.npy' , allow_pickle = True )
312
312
velocities = np .load (f'{ savedir } /velocities.npy' , allow_pickle = True )
0 commit comments