Skip to content

Commit aaedff6

Browse files
authored
Fix edge case where vcenter is not properly set for diverging norms (#314)
* add unittest * add to parse cmap * replace unittest with becker's example * restore init of diverging norm * fix import
1 parent 4ed2539 commit aaedff6

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

ultraplot/axes/plot.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2707,6 +2707,13 @@ def _parse_cmap(
27072707

27082708
# Create the continuous normalizer.
27092709
norm = _not_none(norm, "div" if "diverging" in trues else "linear")
2710+
# If using a diverging norm, fair=True, and vcenter not set, default to midpoint
2711+
if norm in ("div", "diverging") or "diverging" in trues:
2712+
fair = norm_kw.get("fair", True) # defaults to True
2713+
vcenter = norm_kw.get("vcenter", 0)
2714+
if fair and vcenter is None and vmin is not None and vmax is not None:
2715+
vcenter = 0.5 * (vmin + vmax)
2716+
norm_kw["vcenter"] = vcenter
27102717
if isinstance(norm, mcolors.Normalize):
27112718
norm.vmin, norm.vmax = vmin, vmax
27122719
else:
@@ -2939,7 +2946,6 @@ def _parse_level_lim(
29392946
f"Incompatible arguments vmin={vmin!r}, vmax={vmax!r}, and "
29402947
"symmetric=True. Ignoring the latter."
29412948
)
2942-
29432949
return vmin, vmax, kwargs
29442950

29452951
def _parse_level_num(

ultraplot/tests/test_color.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import ultraplot as uplt, numpy as np, pytest
2+
3+
4+
@pytest.mark.mpl_image_compare
5+
def test_vcenter_values():
6+
"""
7+
Test that vcenter values are correctly set in colorbars.
8+
"""
9+
rng = np.random.default_rng(seed=10)
10+
mvals = rng.normal(size=(32, 32))
11+
cmap = "spectral"
12+
# The middle and right plot should look the same
13+
# The colors should spread out where the extremes are visible
14+
fig, axs = uplt.subplots(ncols=3, share=0)
15+
for i, ax in enumerate(axs):
16+
specs = {}
17+
if i > 0:
18+
vmin = -0.2
19+
vmax = 2.0
20+
specs = dict(vmin=vmin, vmax=vmax)
21+
if i == 2:
22+
mvals = np.clip(mvals, vmin, vmax)
23+
m = ax.pcolormesh(
24+
mvals,
25+
cmap=cmap,
26+
discrete=False,
27+
**specs,
28+
)
29+
ax.format(
30+
grid=False,
31+
xticklabels=[],
32+
xticks=[],
33+
yticklabels=[],
34+
yticks=[],
35+
)
36+
ax.colorbar(m, loc="r", label=f"{i}")
37+
return fig

0 commit comments

Comments
 (0)