Skip to content

Commit 82cfea3

Browse files
committed
Add checks for TPS baselines
1 parent 4bf3505 commit 82cfea3

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

tps_baseline.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,13 +294,22 @@ def step_langevin_backward(_x, _v, _key):
294294
}
295295
else:
296296
if os.path.exists(f'{savedir}/paths.npy') and not args.override:
297-
print(f"The target directory is not empy."
297+
print(f"The target directory is not empty.\n"
298298
f"Please use --override to overwrite the existing data or --resume to continue.")
299299
exit(1)
300300

301301
stored = None
302302

303+
assert ((system.start_state(A) and system.target_state(B))
304+
or (system.start_state(B) and system.target_state(A))), \
305+
'A and B are not in the correct states. Please check your settings.'
306+
303307
if args.mechanism == 'one-way-shooting':
308+
assert (system.start_state(initial_trajectory[0])
309+
or system.target_state(initial_trajectory[0])
310+
or system.start_state(initial_trajectory[-1])
311+
or system.target_state(initial_trajectory[-1])
312+
), 'One-Way shooting requires the initial trajectory to start or end in one of the states.'
304313
mechanism = tps2.one_way_shooting
305314
elif args.mechanism == 'two-way-shooting':
306315
mechanism = tps2.two_way_shooting
@@ -325,6 +334,10 @@ def step_langevin_backward(_x, _v, _key):
325334
print(traceback.format_exc())
326335
breakpoint()
327336

337+
if len(paths) == 0:
338+
print("No paths found.")
339+
exit(1)
340+
328341
print(statistics)
329342

330343
if args.fixed_length == 0:

0 commit comments

Comments
 (0)