Skip to content

Commit 1325b4a

Browse files
committed
Use phi psi states in baselines
1 parent e94d743 commit 1325b4a

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tps_baseline.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,10 @@ def step_langevin_backward(_x, _v, _key):
279279
)
280280

281281
system = tps2.SecondOrderSystem(
282-
# jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(A), radius)),
283-
# jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(B), radius)),
284-
jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
285-
jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
282+
jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(A), radius)),
283+
jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(B), radius)),
284+
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
285+
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
286286
step_langevin_forward,
287287
step_langevin_backward,
288288
jax.jit(lambda key: jnp.sqrt(kbT / mass) * jax.random.normal(key, (1, 66)))
@@ -306,7 +306,7 @@ def step_langevin_backward(_x, _v, _key):
306306
initial_trajectory = [p for p in initial_trajectory]
307307
save_trajectory(mdtraj_topology, jnp.array(initial_trajectory), f'{savedir}/initial_trajectory.pdb')
308308

309-
load = True
309+
load = False
310310
if load:
311311
paths = np.load(f'{savedir}/paths.npy', allow_pickle=True)
312312
velocities = np.load(f'{savedir}/velocities.npy', allow_pickle=True)

0 commit comments

Comments
 (0)