Skip to content

Commit 7862d89

Browse files
committed
Allow progress continuation in tps_baselines.py
1 parent 7a30b22 commit 7862d89

File tree

2 files changed

+62
-85
lines changed

2 files changed

+62
-85
lines changed

tps/second_order.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def two_way_shooting(system, trajectory, fixed_length, _dt, key):
149149
return False, new_trajectory, new_velocities
150150

151151

152-
def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixed_length=0, warmup=50):
152+
def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixed_length=0, warmup=50, stored=None):
153153
# pick an initial trajectory
154154
trajectories = [initial_trajectory]
155155
velocities = []
@@ -165,8 +165,14 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixe
165165
if fixed_length > 0:
166166
statistics['fixed_length'] = fixed_length
167167

168+
if stored is not None:
169+
trajectories = stored['trajectories']
170+
velocities = stored['velocities']
171+
statistics = stored['statistics']
172+
168173
try:
169-
with tqdm(total=num_paths + warmup, desc='warming up' if warmup > 0 else '') as pbar:
174+
with tqdm(total=num_paths + warmup, initial=len(trajectories) - 1,
175+
desc='warming up' if warmup > 0 else '') as pbar:
170176
while len(trajectories) <= num_paths + warmup:
171177
statistics['num_tries'] += 1
172178
if len(trajectories) > warmup:
@@ -194,7 +200,7 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixe
194200
except KeyboardInterrupt:
195201
print('SIGINT received, stopping early')
196202
# Fix in case we stop when adding a trajectory
197-
if len(trajectories) > len(velocities):
203+
if len(trajectories) > len(velocities) + 1:
198204
velocities.append(new_velocities)
199205

200206
return trajectories[warmup + 1:], velocities[warmup:], statistics

tps_baseline.py

Lines changed: 53 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
parser.add_argument('--fixed_length', type=int, default=0)
3535
parser.add_argument('--num_steps', type=int, default=10,
3636
help='The number of MD steps taken at once. More takes longer to compile but runs faster in the end.')
37+
parser.add_argument('--resume', action='store_true')
38+
parser.add_argument('--override', action='store_true')
3739

3840

3941
def human_format(num):
@@ -136,7 +138,7 @@ def step_n(step, _x, _v, n, _key):
136138
mdtraj_topology = md.Topology.from_openmm(init_pdb.topology)
137139
phis_psis = phi_psi_from_mdtraj(mdtraj_topology)
138140

139-
savedir = f"out/baselines/alanine-{args.mechanism}"
141+
savedir = f"out/baselines/alanine-{args.mechanism}-{args.fixed_length}"
140142
os.makedirs(savedir, exist_ok=True)
141143

142144
# Construct the mass matrix
@@ -203,7 +205,7 @@ def step_langevin_forward(_x, _v, _key):
203205

204206

205207
@jax.jit
206-
def step_langevin_log_density(_x, _v, _new_x, _new_v):
208+
def step_langevin_log_prob(_x, _v, _new_x, _new_v):
207209
alpha = jnp.exp(-gamma_in_ps * dt_in_ps)
208210
f_scale = (1 - alpha) / gamma_in_ps
209211
new_v_det = alpha * _v + f_scale * -dUdx_fn(_x)
@@ -212,14 +214,14 @@ def step_langevin_log_density(_x, _v, _new_x, _new_v):
212214
return jax.scipy.stats.norm.logpdf(new_v_rand, 0, jnp.sqrt(kbT * (1 - alpha ** 2) / mass)).sum()
213215

214216

215-
def langevin_log_path_density(path_and_velocities):
217+
def langevin_log_path_likelihood(path_and_velocities):
216218
path, velocities = path_and_velocities
217219

218220
log_prob = (-U(path[0]) / kbT).sum()
219221
log_prob += jax.scipy.stats.norm.logpdf(velocities[0], 0, jnp.sqrt(kbT / mass)).sum()
220222

221223
for i in range(1, len(path)):
222-
log_prob += step_langevin_log_density(path[i - 1], velocities[i - 1], path[i], velocities[i])
224+
log_prob += step_langevin_log_prob(path[i - 1], velocities[i - 1], path[i], velocities[i])
223225

224226
return log_prob
225227

@@ -236,34 +238,6 @@ def step_langevin_backward(_x, _v, _key):
236238
return prev_x, prev_v
237239

238240

239-
key = jax.random.PRNGKey(1)
240-
key, velocity_key = jax.random.split(key)
241-
steps = 10_000
242-
243-
trajectory = [A]
244-
_x = trajectory[-1]
245-
_v = jnp.sqrt(kbT / mass) * jax.random.normal(velocity_key, (1, 66))
246-
247-
for i in trange(steps):
248-
key, iter_key = jax.random.split(key)
249-
_x, _v = step_langevin_forward(_x, _v, iter_key)
250-
251-
trajectory.append(_x)
252-
253-
trajectory = jnp.array(trajectory).reshape(-1, 66)
254-
255-
# save_trajectory(mdtraj_topology, trajectory[-1000:], 'simulation.pdb')
256-
257-
# we only need to check whether the last frame contains nan, is it propagates
258-
assert not jnp.isnan(trajectory[-1]).any()
259-
trajectory_phi_psi = phis_psis(trajectory)
260-
261-
plt.title(f"{human_format(steps)} steps @ {temp} K, dt = {human_format(dt)}s")
262-
ramachandran(trajectory_phi_psi)
263-
plt.scatter(phis_psis(A)[0, 0], phis_psis(A)[0, 1], color='red', marker='*')
264-
plt.scatter(phis_psis(B)[0, 0], phis_psis(B)[0, 1], color='green', marker='*')
265-
plt.show()
266-
267241
# Choose a system, either phi psi, or rmsd
268242
# system = tps1.System(
269243
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) < 0.1)),
@@ -273,72 +247,69 @@ def step_langevin_backward(_x, _v, _key):
273247

274248
radius = 20 / deg
275249

276-
system = tps1.FirstOrderSystem(
277-
lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(A), radius),
278-
lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(B), radius),
279-
step
280-
)
250+
# system = tps1.FirstOrderSystem(
251+
# lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(A), radius),
252+
# lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(B), radius),
253+
# step
254+
# )
281255

282256
system = tps2.SecondOrderSystem(
283257
jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(A), radius)),
284258
jax.jit(lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(B), radius)),
285259
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
286260
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
287-
# step_langevin_forward,
288-
# step_langevin_backward,
289261
jax.jit(lambda _x, _v, _key: step_n(step_langevin_forward, _x, _v, args.num_steps, _key)),
290262
jax.jit(lambda _x, _v, _key: step_n(step_langevin_backward, _x, _v, args.num_steps, _key)),
291263
jax.jit(lambda key: jnp.sqrt(kbT / mass) * jax.random.normal(key, (1, 66)))
292264
)
293265

294-
print("A", phis_psis(A))
295-
print("B", phis_psis(B))
296-
297-
filter1 = system.start_state(trajectory)
298-
filter2 = system.target_state(trajectory)
299-
300-
plt.title('start')
301-
ramachandran(trajectory_phi_psi[filter1])
302-
plt.show()
303-
304-
plt.title('target')
305-
ramachandran(trajectory_phi_psi[filter2])
306-
plt.show()
307-
308266
initial_trajectory = md.load('./files/AD_A_B_500K_initial_trajectory.pdb').xyz.reshape(-1, 1, 66)
309267
initial_trajectory = [p for p in initial_trajectory]
310268
save_trajectory(mdtraj_topology, jnp.array(initial_trajectory), f'{savedir}/initial_trajectory.pdb')
311269

312-
load = False
313-
if load:
314-
paths = np.load(f'{savedir}/paths.npy', allow_pickle=True)
315-
velocities = np.load(f'{savedir}/velocities.npy', allow_pickle=True)
270+
if args.resume:
271+
paths = [[x for x in p.astype(np.float32)] for p in np.load(f'{savedir}/paths.npy', allow_pickle=True)]
272+
velocities = [[v for v in p.astype(np.float32)] for p in np.load(f'{savedir}/velocities.npy', allow_pickle=True)]
316273
with open(f'{savedir}/stats.json', 'r') as fp:
317274
statistics = json.load(fp)
275+
276+
stored = {
277+
'trajectories': [initial_trajectory] + paths,
278+
'velocities': velocities,
279+
'statistics': statistics
280+
}
281+
else:
282+
if os.path.exists(f'{savedir}/paths.npy') and not args.override:
283+
print(f"The target directory is not empy."
284+
f"Please use --override to overwrite the existing data or --resume to continue.")
285+
exit(1)
286+
287+
stored = None
288+
289+
if args.mechanism == 'one-way-shooting':
290+
mechanism = tps2.one_way_shooting
291+
elif args.mechanism == 'two-way-shooting':
292+
mechanism = tps2.two_way_shooting
318293
else:
319-
if args.mechanism == 'one-way-shooting':
320-
mechanism = tps2.one_way_shooting
321-
elif args.mechanism == 'two-way-shooting':
322-
mechanism = tps2.two_way_shooting
323-
else:
324-
raise ValueError(f"Unknown mechanism {args.mechanism}")
325-
326-
try:
327-
paths, velocities, statistics = tps2.mcmc_shooting(system, mechanism, initial_trajectory,
328-
100, dt_in_ps, jax.random.PRNGKey(1), warmup=0,
329-
fixed_length=args.fixed_length)
330-
# paths = tps2.unguided_md(system, B, 1, key)
331-
paths = [jnp.array(p) for p in paths]
332-
velocities = [jnp.array(p) for p in velocities]
333-
# store paths
334-
np.save(f'{savedir}/paths.npy', np.array(paths, dtype=object), allow_pickle=True)
335-
np.save(f'{savedir}/velocities.npy', np.array(velocities, dtype=object), allow_pickle=True)
336-
# save statistics, which is a dictionary
337-
with open(f'{savedir}/stats.json', 'w') as fp:
338-
json.dump(statistics, fp)
339-
except Exception as e:
340-
print(e)
341-
breakpoint()
294+
raise ValueError(f"Unknown mechanism {args.mechanism}")
295+
296+
try:
297+
paths, velocities, statistics = tps2.mcmc_shooting(system, mechanism, initial_trajectory,
298+
100, dt_in_ps, jax.random.PRNGKey(1), warmup=0,
299+
fixed_length=args.fixed_length,
300+
stored=stored)
301+
# paths = tps2.unguided_md(system, B, 1, key)
302+
paths = [jnp.array(p) for p in paths]
303+
velocities = [jnp.array(p) for p in velocities]
304+
# store paths
305+
np.save(f'{savedir}/paths.npy', np.array(paths, dtype=object), allow_pickle=True)
306+
np.save(f'{savedir}/velocities.npy', np.array(velocities, dtype=object), allow_pickle=True)
307+
# save statistics, which is a dictionary
308+
with open(f'{savedir}/stats.json', 'w') as fp:
309+
json.dump(statistics, fp)
310+
except Exception as e:
311+
print(e)
312+
breakpoint()
342313

343314
print(statistics)
344315
print([len(p) for p in paths])
@@ -368,8 +339,8 @@ def step_langevin_backward(_x, _v, _key):
368339
plt.savefig(f'{savedir}/median_energy.png', bbox_inches='tight')
369340
plt.show()
370341

371-
plot_path_energy(list(zip(paths, velocities)), langevin_log_path_density, reduce=lambda x: x, already_ln=True)
372-
plt.ylabel('Path Density')
342+
plot_path_energy(list(zip(paths, velocities)), langevin_log_path_likelihood, reduce=lambda x: x, already_ln=True)
343+
plt.ylabel('Path Likelihood')
373344
plt.savefig(f'{savedir}/path_density.png', bbox_inches='tight')
374345
plt.show()
375346

0 commit comments

Comments
 (0)