Skip to content

Commit 9d95064

Browse files
committed
Implement first order and second order TPS, update baseline for second order
1 parent 0fa2275 commit 9d95064

File tree

6 files changed

+394
-31
lines changed

6 files changed

+394
-31
lines changed

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ dependencies:
1717
- matplotlib==3.8.2
1818
- rdkit==2023.3.3
1919
- ParmEd==4.2.2
20+
- scikit-image==0.23.2

tps.py renamed to tps/first_order.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
MAX_STEPS = 1_000
66

77

8-
class System:
8+
class FirstOrderSystem:
99
def __init__(self, start_state, target_state, step):
1010
self.start_state = start_state
1111
self.target_state = target_state
@@ -18,7 +18,7 @@ def one_way_shooting(system, trajectory, fixed_length, key):
1818
# pick a random point along the trajectory
1919
point_idx = jax.random.randint(key[0], (1,), 1, len(trajectory) - 1)[0]
2020
# pick a random direction, either forward or backward
21-
direction = jax.random.randint(key[1], (1,), 0, 2)[0]
21+
direction = jax.random.randint(key[1], (1,), 0, 2)[0]
2222

2323
if direction == 0:
2424
trajectory = trajectory[:point_idx + 1]

tps/plot.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import numpy as np
2+
from skimage.draw import line
3+
from tqdm import tqdm
4+
import matplotlib.pyplot as plt
5+
6+
7+
class PeriodicPathHistogram:
8+
def __init__(self, bins=250, interpolate=True, scale=np.pi):
9+
self.bins = bins
10+
self.interpolate = interpolate
11+
self.scale = scale
12+
self.hist = np.zeros((bins, bins))
13+
14+
def add_paths(self, paths: list[np.ndarray], factors: list[float] = None):
15+
for path, factor in tqdm(zip(paths, factors or [1] * len(paths)), total=len(paths)):
16+
self.add_path(path, factor=factor)
17+
18+
def add_path(self, path: np.ndarray, factor: float = 1):
19+
"""
20+
Adds a path to the histogram. The path is a list of 2D points in the range [-scale, scale]
21+
"""
22+
rr, cc = np.array([], dtype=int), np.array([], dtype=int)
23+
24+
if self.interpolate:
25+
for i in range(len(path) - 1):
26+
rr_cur, cc_cur = self._add_path_segment_periodic(path[i], path[i + 1])
27+
rr = np.concatenate([rr, rr_cur])
28+
cc = np.concatenate([cc, cc_cur])
29+
else:
30+
for p in path:
31+
point = ((p + self.scale) / (2 * self.scale) * (self.bins - 1)).astype(int)
32+
cc = np.concatenate([cc, [point[0]]])
33+
rr = np.concatenate([rr, [point[1]]])
34+
35+
# we only add it once for each path, so that overlapping segments are not counted multiple times
36+
self.hist[rr, cc] += factor
37+
38+
def _add_path_segment_periodic(self, start: np.ndarray, stop: np.ndarray):
39+
start = np.array(start)
40+
stop = np.array(stop)
41+
42+
if np.linalg.norm(start - stop) < self.scale:
43+
return self._determine_path_segments(start, stop)
44+
45+
possible_offsets = [
46+
np.array([0, 2 * self.scale]),
47+
np.array([0, -2 * self.scale]),
48+
np.array([2 * self.scale, 0]),
49+
np.array([-2 * self.scale, 0]),
50+
np.array([2 * self.scale, 2 * self.scale]),
51+
np.array([-2 * self.scale, 2 * self.scale]),
52+
np.array([2 * self.scale, -2 * self.scale]),
53+
np.array([-2 * self.scale, -2 * self.scale]),
54+
]
55+
56+
def add_shortest_segment(point, target):
57+
distances = np.array([np.linalg.norm((target + offset) - point) for offset in possible_offsets])
58+
best_offset_idx = np.argmin(distances)
59+
60+
best_target = target + possible_offsets[best_offset_idx]
61+
return self._determine_path_segments(point, best_target)
62+
63+
# just try each possible combination and use the shortest path
64+
rr1, cc1 = add_shortest_segment(start, stop)
65+
rr2, cc2 = add_shortest_segment(stop, start)
66+
return np.concatenate([rr1, rr2]), np.concatenate([cc1, cc2])
67+
68+
def _determine_path_segments(self, start: np.ndarray, stop: np.ndarray):
69+
"""
70+
Start and stop are 2D points in the range [-scale, scale].
71+
This function converts the points into the corresponding bins and then uses a line to connect those points
72+
"""
73+
start = ((start + self.scale) / (2 * self.scale) * (self.bins - 1)).astype(int)
74+
stop = ((stop + self.scale) / (2 * self.scale) * (self.bins - 1)).astype(int)
75+
76+
rr, cc = line(start[1], start[0], stop[1], stop[0])
77+
rr_mask = (rr >= 0) & (rr < self.bins)
78+
cc_mask = (cc >= 0) & (cc < self.bins)
79+
mask = rr_mask & cc_mask
80+
rr, cc = rr[mask], cc[mask]
81+
82+
return rr, cc
83+
84+
def plot(self, density=False, cmin=None, cmax=None, **kwargs):
85+
H = self.hist.copy()
86+
if density:
87+
H /= H.sum() * (2 * self.scale / self.bins) ** 2
88+
89+
if cmin is not None:
90+
H[H < cmin] = None
91+
if cmax is not None:
92+
H[H > cmax] = None
93+
94+
x = np.linspace(self.scale, -self.scale, self.bins)
95+
y = np.linspace(self.scale, -self.scale, self.bins)
96+
xv, yv = np.meshgrid(x, y)
97+
98+
plt.pcolormesh(xv, yv, np.flip(H), **kwargs)
99+
ticks = np.arange(-self.scale, self.scale + self.scale * 0.01, self.scale / 4)
100+
101+
plt.xlim(-self.scale, self.scale)
102+
plt.ylim(-self.scale, self.scale)
103+
plt.xlabel(r"$\phi$")
104+
plt.ylabel(r"$\psi$")
105+
plt.xticks(ticks)
106+
plt.yticks(ticks)

tps/second_order.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import jax
2+
import jax.numpy as jnp
3+
from tqdm import tqdm
4+
5+
MAX_STEPS = 2_000
6+
MAX_ABS_VALUE = 5
7+
8+
9+
class SecondOrderSystem:
10+
def __init__(self, start_state, target_state, step_forward, step_backward, sample_velocity):
11+
self.start_state = start_state
12+
self.target_state = target_state
13+
self.step_forward = step_forward
14+
self.step_backward = step_backward
15+
self.sample_velocity = sample_velocity
16+
17+
18+
def one_way_shooting(system, trajectory, fixed_length, key):
19+
key = jax.random.split(key)
20+
21+
# pick a random point along the trajectory
22+
point_idx = jax.random.randint(key[0], (1,), 1, len(trajectory) - 1)[0]
23+
# pick a random direction, either forward or backward
24+
direction = jax.random.randint(key[1], (1,), 0, 2)[0]
25+
26+
# TODO: Fix correct dt in ps
27+
velocity = (trajectory[point_idx] - trajectory[point_idx - 1]) / 0.001
28+
29+
if direction == 0:
30+
trajectory = trajectory[:point_idx + 1]
31+
step_function = system.step_forward
32+
else: # direction == 1:
33+
trajectory = trajectory[point_idx:][::-1]
34+
step_function = system.step_backward
35+
36+
steps = MAX_STEPS if fixed_length == 0 else fixed_length
37+
38+
key, iter_key = jax.random.split(key[3])
39+
while len(trajectory) < steps:
40+
key, iter_key = jax.random.split(key)
41+
point, velocity = step_function(trajectory[-1], velocity, iter_key)
42+
trajectory.append(point)
43+
44+
if jnp.isnan(point).any() or jnp.isnan(velocity).any():
45+
return False, trajectory
46+
47+
# ensure that our trajectory does not explode
48+
if (jnp.abs(point) > MAX_ABS_VALUE).any():
49+
return False, trajectory
50+
51+
if system.start_state(trajectory[0]) and system.target_state(trajectory[-1]):
52+
if fixed_length == 0 or len(trajectory) == fixed_length:
53+
return True, trajectory
54+
return False, trajectory
55+
56+
if system.target_state(trajectory[0]) and system.start_state(trajectory[-1]):
57+
if fixed_length == 0 or len(trajectory) == fixed_length:
58+
return True, trajectory[::-1]
59+
return False, trajectory
60+
61+
return False, trajectory
62+
63+
64+
def two_way_shooting(system, trajectory, fixed_length, key):
65+
key = jax.random.split(key)
66+
67+
# pick a random point along the trajectory
68+
point_idx = jax.random.randint(key[0], (1,), 1, len(trajectory) - 1)[0]
69+
point = trajectory[point_idx]
70+
# simulate forward from the point until max_steps
71+
72+
steps = MAX_STEPS if fixed_length == 0 else fixed_length
73+
74+
initial_velocity = system.sample_velocity(key[1])
75+
76+
key, iter_key = jax.random.split(key[2])
77+
new_trajectory = [point]
78+
79+
velocity = initial_velocity
80+
while len(new_trajectory) < steps:
81+
key, iter_key = jax.random.split(key)
82+
point, velocity = system.step_forward(new_trajectory[-1], velocity, iter_key)
83+
new_trajectory.append(point)
84+
85+
if jnp.isnan(point).any() or jnp.isnan(velocity).any():
86+
return False, trajectory
87+
88+
# ensure that our trajectory does not explode
89+
if (jnp.abs(point) > MAX_ABS_VALUE).any():
90+
return False, trajectory
91+
92+
if system.start_state(point) or system.target_state(point):
93+
break
94+
95+
velocity = initial_velocity
96+
while len(new_trajectory) < steps:
97+
key, iter_key = jax.random.split(key)
98+
point, velocity = system.step_backward(new_trajectory[0], velocity, iter_key)
99+
new_trajectory.insert(0, point)
100+
101+
if jnp.isnan(point).any() or jnp.isnan(velocity).any():
102+
return False, trajectory
103+
104+
# ensure that our trajectory does not explode
105+
if (jnp.abs(point) > MAX_ABS_VALUE).any():
106+
return False, trajectory
107+
108+
if system.start_state(point) or system.target_state(point):
109+
break
110+
111+
# throw away the trajectory if it's not the right length
112+
if fixed_length != 0 and len(new_trajectory) != fixed_length:
113+
return False, trajectory
114+
115+
if system.start_state(new_trajectory[0]) and system.target_state(new_trajectory[-1]):
116+
return True, new_trajectory
117+
118+
if system.target_state(new_trajectory[0]) and system.start_state(new_trajectory[-1]):
119+
return True, new_trajectory[::-1]
120+
121+
return False, trajectory
122+
123+
124+
def mcmc_shooting(system, proposal, initial_trajectory, num_paths, key, fixed_length=0, warmup=50):
125+
# pick an initial trajectory
126+
trajectories = [initial_trajectory]
127+
128+
with tqdm(total=num_paths + warmup, desc='warming up' if warmup > 0 else '') as pbar:
129+
while len(trajectories) <= num_paths + warmup:
130+
if len(trajectories) > warmup:
131+
pbar.set_description('')
132+
133+
key, traj_idx_key, iter_key, accept_key = jax.random.split(key, 4)
134+
traj_idx = jax.random.randint(traj_idx_key, (1,), warmup + 1, len(trajectories))[0]
135+
# during warmup, we want an iterative scheme
136+
traj_idx = traj_idx if traj_idx < len(trajectories) else -1
137+
138+
found, new_trajectory = proposal(system, trajectories[traj_idx], fixed_length, iter_key)
139+
140+
if not found:
141+
continue
142+
143+
ratio = len(trajectories[-1]) / len(new_trajectory)
144+
# The first trajectory might have a very unreasonable length, so we skip it
145+
if len(trajectories) == 1 or jax.random.uniform(accept_key, shape=(1,)) < ratio:
146+
trajectories.append(new_trajectory)
147+
pbar.update(1)
148+
149+
return trajectories[warmup + 1:]
150+
151+
152+
def unguided_md(system, initial_point, num_paths, key, fixed_length=0):
153+
trajectories = []
154+
current_frame = initial_point.clone()
155+
current_trajectory = []
156+
157+
key, velocity_key = jax.random.split(key)
158+
velocity = system.sample_velocity(velocity_key)
159+
160+
with tqdm(total=num_paths) as pbar:
161+
while len(trajectories) < num_paths:
162+
key, iter_key = jax.random.split(key)
163+
next_frame, velocity = system.step_forward(current_frame, velocity, iter_key)
164+
165+
assert not jnp.isnan(next_frame).any()
166+
167+
is_transition = not (system.start_state(next_frame) or system.target_state(next_frame))
168+
if is_transition:
169+
if len(current_trajectory) == 0:
170+
current_trajectory.append(current_frame)
171+
172+
if fixed_length != 0 and len(current_trajectory) > fixed_length:
173+
current_trajectory = []
174+
is_transition = False
175+
else:
176+
current_trajectory.append(next_frame)
177+
elif len(current_trajectory) > 0:
178+
current_trajectory.append(next_frame)
179+
180+
if fixed_length == 0 or len(current_trajectory) == fixed_length:
181+
if system.start_state(current_trajectory[0]) and system.target_state(current_trajectory[-1]):
182+
trajectories.append(current_trajectory)
183+
pbar.update(1)
184+
elif system.target_state(current_trajectory[0]) and system.start_state(current_trajectory[-1]):
185+
trajectories.append(current_trajectory[::-1])
186+
pbar.update(1)
187+
current_trajectory = []
188+
189+
current_frame = next_frame
190+
191+
return trajectories

0 commit comments

Comments
 (0)