Skip to content

Commit e719a3c

Browse files
committed
Add a test to check ndarray subclass handling
1 parent 50e5f76 commit e719a3c

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

lib/matplotlib/tests/test_image.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,3 +1124,66 @@ def test_exact_vmin():
11241124
@pytest.mark.flaky
11251125
def test_https_imread_smoketest():
11261126
v = mimage.imread('https://matplotlib.org/1.5.0/_static/logo2.png')
1127+
1128+
1129+
# A basic ndarray subclass that implements a quantity
1130+
# It does not implement an entire unit system or all quantity math.
1131+
# There is just enough implemented to test handling of ndarray
1132+
# subclasses.
1133+
class QuantityND(np.ndarray):
1134+
def __new__(cls, input_array, units):
1135+
obj = np.asarray(input_array).view(cls)
1136+
obj.units = units
1137+
return obj
1138+
1139+
def __array_finalize__(self, obj):
1140+
self.units = getattr(obj, "units", None)
1141+
1142+
def __getitem__(self, item):
1143+
units = getattr(self, "units", None)
1144+
ret = super(QuantityND, self).__getitem__(item)
1145+
if isinstance(ret, QuantityND) or units is not None:
1146+
return QuantityND(ret, units)
1147+
return ret
1148+
1149+
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
1150+
func = getattr(ufunc, method)
1151+
if "out" in kwargs:
1152+
raise NotImplementedError
1153+
if len(inputs) == 1:
1154+
i0 = inputs[0]
1155+
unit = getattr(i0, "units", "dimensionless")
1156+
out_arr = func(np.asarray(i0), **kwargs)
1157+
elif len(inputs) == 2:
1158+
i0 = inputs[0]
1159+
i1 = inputs[1]
1160+
u0 = getattr(i0, "units", "dimensionless")
1161+
u1 = getattr(i1, "units", "dimensionless")
1162+
u0 = u1 if u0 is None else u0
1163+
u1 = u0 if u1 is None else u1
1164+
if ufunc in [np.add, np.subtract]:
1165+
if u0 != u1:
1166+
raise ValueError
1167+
unit = u0
1168+
elif ufunc == np.multiply:
1169+
unit = f"{u0}*{u1}"
1170+
elif ufunc == np.divide:
1171+
unit = f"{u0}/({u1})"
1172+
else:
1173+
raise NotImplementedError
1174+
out_arr = func(i0.view(np.ndarray), i1.view(np.ndarray), **kwargs)
1175+
else:
1176+
raise NotImplementedError
1177+
if unit is None:
1178+
out_arr = np.array(out_arr)
1179+
else:
1180+
out_arr = QuantityND(out_arr, unit)
1181+
return out_arr
1182+
1183+
def test_imshow_quantitynd():
1184+
# generate a dummy ndarray subclass
1185+
arr = QuantityND(np.ones((2,2)), "m")
1186+
fig, ax = plt.subplots()
1187+
ax.imshow(arr)
1188+
# executing the draw should not raise an exception
1189+
fig.canvas.draw()

0 commit comments

Comments
 (0)