@@ -1124,3 +1124,66 @@ def test_exact_vmin():
11241124@pytest .mark .flaky
11251125def test_https_imread_smoketest ():
11261126 v = mimage .imread ('https://matplotlib.org/1.5.0/_static/logo2.png' )
1127+
1128+
1129+ # A basic ndarray subclass that implements a quantity
1130+ # It does not implement an entire unit system or all quantity math.
1131+ # There is just enough implemented to test handling of ndarray
1132+ # subclasses.
1133+ class QuantityND (np .ndarray ):
1134+ def __new__ (cls , input_array , units ):
1135+ obj = np .asarray (input_array ).view (cls )
1136+ obj .units = units
1137+ return obj
1138+
1139+ def __array_finalize__ (self , obj ):
1140+ self .units = getattr (obj , "units" , None )
1141+
1142+ def __getitem__ (self , item ):
1143+ units = getattr (self , "units" , None )
1144+ ret = super (QuantityND , self ).__getitem__ (item )
1145+ if isinstance (ret , QuantityND ) or units is not None :
1146+ return QuantityND (ret , units )
1147+ return ret
1148+
1149+ def __array_ufunc__ (self , ufunc , method , * inputs , ** kwargs ):
1150+ func = getattr (ufunc , method )
1151+ if "out" in kwargs :
1152+ raise NotImplementedError
1153+ if len (inputs ) == 1 :
1154+ i0 = inputs [0 ]
1155+ unit = getattr (i0 , "units" , "dimensionless" )
1156+ out_arr = func (np .asarray (i0 ), ** kwargs )
1157+ elif len (inputs ) == 2 :
1158+ i0 = inputs [0 ]
1159+ i1 = inputs [1 ]
1160+ u0 = getattr (i0 , "units" , "dimensionless" )
1161+ u1 = getattr (i1 , "units" , "dimensionless" )
1162+ u0 = u1 if u0 is None else u0
1163+ u1 = u0 if u1 is None else u1
1164+ if ufunc in [np .add , np .subtract ]:
1165+ if u0 != u1 :
1166+ raise ValueError
1167+ unit = u0
1168+ elif ufunc == np .multiply :
1169+ unit = f"{ u0 } *{ u1 } "
1170+ elif ufunc == np .divide :
1171+ unit = f"{ u0 } /({ u1 } )"
1172+ else :
1173+ raise NotImplementedError
1174+ out_arr = func (i0 .view (np .ndarray ), i1 .view (np .ndarray ), ** kwargs )
1175+ else :
1176+ raise NotImplementedError
1177+ if unit is None :
1178+ out_arr = np .array (out_arr )
1179+ else :
1180+ out_arr = QuantityND (out_arr , unit )
1181+ return out_arr
1182+
1183+ def test_imshow_quantitynd ():
1184+ # generate a dummy ndarray subclass
1185+ arr = QuantityND (np .ones ((2 ,2 )), "m" )
1186+ fig , ax = plt .subplots ()
1187+ ax .imshow (arr )
1188+ # executing the draw should not raise an exception
1189+ fig .canvas .draw ()
0 commit comments