Skip to content

Commit 4512926

Browse files
committed
Improve plotting and mueller baselines
1 parent 5fcc35d commit 4512926

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

tps_baseline_mueller.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import numpy as np
1010
import utils.toy_plot_helpers as toy
1111

12-
minima_points = jnp.array([[-0.55828035, 1.44169], [-0.05004308, 0.46666032], [0.62361133, 0.02804632]])
12+
minima_points = jnp.array([[-0.55828035, 1.44169],
13+
#[-0.05004308, 0.46666032],
14+
[0.62361133, 0.02804632]])
1315
A, B = minima_points[None, 0], minima_points[None, 2]
1416

1517

@@ -46,20 +48,23 @@ def interpolate_two_points(start, stop, steps):
4648
return interpolation
4749

4850

49-
plot_energy_surface = partial(toy.plot_energy_surface, U=U, states=zip(['A', 'B', 'C'], minima_points),
50-
xlim=jnp.array((-1.5, 0.9)), ylim=jnp.array((-0.5, 1.7)))
51+
plot_energy_surface = partial(toy.plot_energy_surface, U=U, states=zip(['A', 'B'], minima_points),
52+
xlim=jnp.array((-1.5, 0.9)), ylim=jnp.array((-0.5, 1.7)), alpha=1.0)
5153

5254
if __name__ == '__main__':
55+
variable = True
5356
savedir = f"out/baselines/mueller"
57+
if variable:
58+
savedir += "-variable"
59+
5460
os.makedirs(savedir, exist_ok=True)
5561

5662
num_paths = 1000
5763
xi = 5
5864
dt = 1e-4
5965
T = 275e-4
60-
N = int(T / dt)
61-
initial_trajectory = [t.reshape(1, 2) for t in interpolate(minima_points, 100 if N == 0 else N)]
62-
66+
N = 0 if variable else int(T / dt)
67+
initial_trajectory = [t.reshape(1, 2) for t in interpolate(minima_points, 100 if variable else N)]
6368

6469
@jax.jit
6570
def step(_x, _key):
@@ -97,3 +102,4 @@ def step(_x, _key):
97102
plot_energy_surface(trajectories=paths)
98103
plt.savefig(f'{savedir}/mueller-{name}.pdf', bbox_inches='tight')
99104
plt.show()
105+
plt.clf()

utils/toy_plot_helpers.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
import matplotlib.pyplot as plt
33

44

5-
def plot_energy_surface(U, states, xlim, ylim, points=[], trajectories=[], bins=150, levels=30, alpha=0.7):
5+
def plot_energy_surface(U, states, xlim, ylim, points=[], trajectories=[], bins=150, levels=30, alpha=0.7, radius=0.1):
66
x, y = jnp.linspace(xlim[0], xlim[1], bins), jnp.linspace(ylim[0], ylim[1], bins)
77
x, y = jnp.meshgrid(x, y, indexing='ij')
88
z = U(jnp.stack([x, y], -1).reshape(-1, 2)).reshape([bins, bins])
99

1010
# black and white contour plot
11-
plt.contour(x, y, z, levels=levels, cmap='gray')
11+
plt.contour(x, y, z, levels=levels, colors='black')
1212

1313
plt.xlim(xlim[0], xlim[1])
1414
plt.ylim(ylim[0], ylim[1])
@@ -35,12 +35,13 @@ def plot_energy_surface(U, states, xlim, ylim, points=[], trajectories=[], bins=
3535
rasterized=True
3636
)
3737

38-
plt.colorbar()
38+
plt.xticks([])
39+
plt.yticks([])
3940

4041
for p in points:
4142
plt.scatter(p[0], p[1], marker='*')
4243

4344
for name, pos in states:
44-
c = plt.Circle(pos, radius=0.1, edgecolor='gray', alpha=alpha, facecolor='white', ls='--', lw=0.7)
45+
c = plt.Circle(pos, radius=radius, edgecolor='gray', alpha=alpha, facecolor='white', ls='--', lw=0.7, zorder=100)
4546
plt.gca().add_patch(c)
46-
plt.gca().annotate(name, xy=pos, ha="center", va="center")
47+
plt.gca().annotate(name, xy=pos, ha="center", va="center", fontsize=14, zorder=101)

0 commit comments

Comments
 (0)