2
2
import matplotlib .pyplot as plt
3
3
4
4
5
- def plot_energy_surface (U , states , xlim , ylim , points = [], trajectories = [], bins = 150 , levels = 30 , alpha = 0.7 ):
5
+ def plot_energy_surface (U , states , xlim , ylim , points = [], trajectories = [], bins = 150 , levels = 30 , alpha = 0.7 , radius = 0.1 ):
6
6
x , y = jnp .linspace (xlim [0 ], xlim [1 ], bins ), jnp .linspace (ylim [0 ], ylim [1 ], bins )
7
7
x , y = jnp .meshgrid (x , y , indexing = 'ij' )
8
8
z = U (jnp .stack ([x , y ], - 1 ).reshape (- 1 , 2 )).reshape ([bins , bins ])
9
9
10
10
# black and white contour plot
11
- plt .contour (x , y , z , levels = levels , cmap = 'gray ' )
11
+ plt .contour (x , y , z , levels = levels , colors = 'black ' )
12
12
13
13
plt .xlim (xlim [0 ], xlim [1 ])
14
14
plt .ylim (ylim [0 ], ylim [1 ])
@@ -35,12 +35,13 @@ def plot_energy_surface(U, states, xlim, ylim, points=[], trajectories=[], bins=
35
35
rasterized = True
36
36
)
37
37
38
- plt .colorbar ()
38
+ plt .xticks ([])
39
+ plt .yticks ([])
39
40
40
41
for p in points :
41
42
plt .scatter (p [0 ], p [1 ], marker = '*' )
42
43
43
44
for name , pos in states :
44
- c = plt .Circle (pos , radius = 0.1 , edgecolor = 'gray' , alpha = alpha , facecolor = 'white' , ls = '--' , lw = 0.7 )
45
+ c = plt .Circle (pos , radius = radius , edgecolor = 'gray' , alpha = alpha , facecolor = 'white' , ls = '--' , lw = 0.7 , zorder = 100 )
45
46
plt .gca ().add_patch (c )
46
- plt .gca ().annotate (name , xy = pos , ha = "center" , va = "center" )
47
+ plt .gca ().annotate (name , xy = pos , ha = "center" , va = "center" , fontsize = 14 , zorder = 101 )
0 commit comments