Skip to content

Commit 8ff6594

Browse files
Merge pull request #134 from colleenjg/cjg-dev
Minor plotting modifications
2 parents 2a89d6f + 119e28b commit 8ff6594

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

ratinabox/Environment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,8 @@ def plot_environment(self,
410410
Returns:
411411
fig, ax: the environment figures, can be used for further downstream plotting.
412412
"""
413+
wall_lw = kwargs.get("wall_lw", 4.0) # wall linewidth for 2D environment
414+
413415
if self.dimensionality == "1D":
414416
extent = self.extent
415417
if fig is None and ax is None:
@@ -509,7 +511,7 @@ def plot_environment(self,
509511
[wall[0][0], wall[1][0]],
510512
[wall[0][1], wall[1][1]],
511513
color=ratinabox.GREY,
512-
linewidth=4.0,
514+
linewidth=wall_lw,
513515
solid_capstyle="round",
514516
zorder=2,
515517
)

ratinabox/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,7 @@ def mountain_plot(
638638
fc = 0.3 * c + (1 - 0.3) * np.array([1, 1, 1]) # convert rgb+alpha to rgb
639639
norm = np.max(np.abs(NbyX)) if norm_by == "max" else norm_by
640640
global_shift = kwargs.get("global_shift", 0) #any additional shift to add to each of the lines
641+
shade_skiprate = kwargs.get("shade_skiprate", 1) # skip rate for plotting shading (1 for every points, etc.)
641642
if norm <= 1e-6: norm=100 #large
642643
NbyX = overlap * NbyX / norm
643644
if fig is None and ax is None:
@@ -654,11 +655,12 @@ def mountain_plot(
654655

655656
zorder = 1
656657
X_ = X.copy()
658+
mask = np.arange(len(X_))[::shade_skiprate]
657659
if nan_bins is not None: X_[nan_bins] = np.nan
658660
for i in range(len(NbyX)):
659661
ax.plot(X_, NbyX[i] + i + 1 + global_shift, c=c, zorder=zorder, lw=linewidth)
660662
zorder -= 0.01
661-
ax.fill_between(X_, NbyX[i] + i + 1 + global_shift, i + 1 + global_shift, color=fc, zorder=zorder, alpha=0.8, linewidth=0, **shade_kwargs)
663+
ax.fill_between(X_[mask], NbyX[i][mask] + i + 1 + global_shift, i + 1 + global_shift, color=fc, zorder=zorder, alpha=0.8, linewidth=0, **shade_kwargs)
662664
zorder -= 0.01
663665
ax.spines["left"].set_bounds(1, len(NbyX))
664666
ax.spines["bottom"].set_position(("outward", 1))

0 commit comments

Comments
 (0)