Skip to content

Commit 8cd22b4

Browse files
committed
Apply unit conversion early in errorbar().
This allow using normal numpy constructs rather than manually looping and broadcasting. _process_unit_info was already special-handling `data is None` in a few places; the change here only handle the (theoretical) extra case where a custom unit converter would fail to properly pass None through.
1 parent d235b02 commit 8cd22b4

File tree

3 files changed

+33
-46
lines changed

3 files changed

+33
-46
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3281,27 +3281,19 @@ def errorbar(self, x, y, yerr=None, xerr=None,
32813281
kwargs = {k: v for k, v in kwargs.items() if v is not None}
32823282
kwargs.setdefault('zorder', 2)
32833283

3284-
self._process_unit_info([("x", x), ("y", y)], kwargs, convert=False)
3285-
3286-
# Make sure all the args are iterable; use lists not arrays to preserve
3287-
# units.
3288-
if not np.iterable(x):
3289-
x = [x]
3290-
3291-
if not np.iterable(y):
3292-
y = [y]
3293-
3284+
# Casting to object arrays preserves units.
3285+
if not isinstance(x, np.ndarray):
3286+
x = np.asarray(x, dtype=object)
3287+
if not isinstance(y, np.ndarray):
3288+
y = np.asarray(y, dtype=object)
3289+
if xerr is not None and not isinstance(xerr, np.ndarray):
3290+
xerr = np.asarray(xerr, dtype=object)
3291+
if yerr is not None and not isinstance(yerr, np.ndarray):
3292+
yerr = np.asarray(yerr, dtype=object)
3293+
x, y = np.atleast_1d(x, y) # Make sure all the args are iterable.
32943294
if len(x) != len(y):
32953295
raise ValueError("'x' and 'y' must have the same size")
32963296

3297-
if xerr is not None:
3298-
if not np.iterable(xerr):
3299-
xerr = [xerr] * len(x)
3300-
3301-
if yerr is not None:
3302-
if not np.iterable(yerr):
3303-
yerr = [yerr] * len(y)
3304-
33053297
if isinstance(errorevery, Integral):
33063298
errorevery = (0, errorevery)
33073299
if isinstance(errorevery, tuple):
@@ -3313,10 +3305,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
33133305
raise ValueError(
33143306
f'errorevery={errorevery!r} is a not a tuple of two '
33153307
f'integers')
3316-
33173308
elif isinstance(errorevery, slice):
33183309
pass
3319-
33203310
elif not isinstance(errorevery, str) and np.iterable(errorevery):
33213311
# fancy indexing
33223312
try:
@@ -3328,6 +3318,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
33283318
else:
33293319
raise ValueError(
33303320
f"errorevery={errorevery!r} is not a recognized value")
3321+
everymask = np.zeros(len(x), bool)
3322+
everymask[errorevery] = True
33313323

33323324
label = kwargs.pop("label", None)
33333325
kwargs['label'] = '_nolegend_'
@@ -3410,13 +3402,8 @@ def errorbar(self, x, y, yerr=None, xerr=None,
34103402
xlolims = np.broadcast_to(xlolims, len(x)).astype(bool)
34113403
xuplims = np.broadcast_to(xuplims, len(x)).astype(bool)
34123404

3413-
everymask = np.zeros(len(x), bool)
3414-
everymask[errorevery] = True
3415-
3416-
def apply_mask(arrays, mask):
3417-
# Return, for each array in *arrays*, the elements for which *mask*
3418-
# is True, without using fancy indexing.
3419-
return [[*itertools.compress(array, mask)] for array in arrays]
3405+
# Vectorized fancy-indexer.
3406+
def apply_mask(arrays, mask): return [array[mask] for array in arrays]
34203407

34213408
def extract_err(name, err, data, lolims, uplims):
34223409
"""
@@ -3437,24 +3424,14 @@ def extract_err(name, err, data, lolims, uplims):
34373424
Error is only applied on **lower** side when this is True. See
34383425
the note in the main docstring about this parameter's name.
34393426
"""
3440-
try: # Asymmetric error: pair of 1D iterables.
3441-
a, b = err
3442-
iter(a)
3443-
iter(b)
3444-
except (TypeError, ValueError):
3445-
a = b = err # Symmetric error: 1D iterable.
3446-
if np.ndim(a) > 1 or np.ndim(b) > 1:
3427+
try:
3428+
low, high = np.broadcast_to(err, (2, len(data)))
3429+
except ValueError:
34473430
raise ValueError(
3448-
f"{name}err must be a scalar or a 1D or (2, n) array-like")
3449-
# Using list comprehensions rather than arrays to preserve units.
3450-
for e in [a, b]:
3451-
if len(data) != len(e):
3452-
raise ValueError(
3453-
f"The lengths of the data ({len(data)}) and the "
3454-
f"error {len(e)} do not match")
3455-
low = [v if lo else v - e for v, e, lo in zip(data, a, lolims)]
3456-
high = [v if up else v + e for v, e, up in zip(data, b, uplims)]
3457-
return low, high
3431+
f"'{name}err' (shape: {np.shape(err)}) must be a scalar "
3432+
f"or a 1D or (2, n) array-like whose shape matches "
3433+
f"'{name}' (shape: {np.shape(data)})") from None
3434+
return data - low * ~lolims, data + high * ~uplims # low, high
34583435

34593436
if xerr is not None:
34603437
left, right = extract_err('x', xerr, x, xlolims, xuplims)

lib/matplotlib/axes/_base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2312,7 +2312,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
23122312
----------
23132313
datasets : list
23142314
List of (axis_name, dataset) pairs (where the axis name is defined
2315-
as in `._get_axis_map`.
2315+
as in `._get_axis_map`). Individual datasets can also be None
2316+
(which gets passed through).
23162317
kwargs : dict
23172318
Other parameters from which unit info (i.e., the *xunits*,
23182319
*yunits*, *zunits* (for 3D axes), *runits* and *thetaunits* (for
@@ -2359,7 +2360,8 @@ def _process_unit_info(self, datasets=None, kwargs=None, *, convert=True):
23592360
for dataset_axis_name, data in datasets:
23602361
if dataset_axis_name == axis_name and data is not None:
23612362
axis.update_units(data)
2362-
return [axis_map[axis_name].convert_units(data) if convert else data
2363+
return [axis_map[axis_name].convert_units(data)
2364+
if convert and data is not None else data
23632365
for axis_name, data in datasets]
23642366

23652367
def in_axes(self, mouseevent):

lib/matplotlib/tests/test_units.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,14 @@ def test_scatter_element0_masked():
166166
fig.canvas.draw()
167167

168168

169+
def test_errorbar_mixed_units():
170+
x = np.arange(10)
171+
y = [datetime(2020, 5, i * 2 + 1) for i in x]
172+
fig, ax = plt.subplots()
173+
ax.errorbar(x, y, timedelta(days=0.5))
174+
fig.canvas.draw()
175+
176+
169177
@check_figures_equal(extensions=["png"])
170178
def test_subclass(fig_test, fig_ref):
171179
class subdate(datetime):

0 commit comments

Comments
 (0)