@@ -1124,3 +1124,83 @@ def test_exact_vmin():
11241124@pytest .mark .flaky
11251125def test_https_imread_smoketest ():
11261126 v = mimage .imread ('https://matplotlib.org/1.5.0/_static/logo2.png' )
1127+
1128+
1129+ # A basic ndarray subclass that implements a quantity
1130+ # It does not implement an entire unit system or all quantity math.
1131+ # There is just enough implemented to test handling of ndarray
1132+ # subclasses.
1133+ class QuantityND (np .ndarray ):
1134+ def __new__ (cls , input_array , units ):
1135+ obj = np .asarray (input_array ).view (cls )
1136+ obj .units = units
1137+ return obj
1138+
1139+ def __array_finalize__ (self , obj ):
1140+ self .units = getattr (obj , "units" , None )
1141+
1142+ def __getitem__ (self , item ):
1143+ units = getattr (self , "units" , None )
1144+ ret = super (QuantityND , self ).__getitem__ (item )
1145+ if isinstance (ret , QuantityND ) or units is not None :
1146+ ret = QuantityND (ret , units )
1147+ return ret
1148+
1149+ def __array_ufunc__ (self , ufunc , method , * inputs , ** kwargs ):
1150+ func = getattr (ufunc , method )
1151+ if "out" in kwargs :
1152+ raise NotImplementedError
1153+ if len (inputs ) == 1 :
1154+ i0 = inputs [0 ]
1155+ unit = getattr (i0 , "units" , "dimensionless" )
1156+ out_arr = func (np .asarray (i0 ), ** kwargs )
1157+ elif len (inputs ) == 2 :
1158+ i0 = inputs [0 ]
1159+ i1 = inputs [1 ]
1160+ u0 = getattr (i0 , "units" , "dimensionless" )
1161+ u1 = getattr (i1 , "units" , "dimensionless" )
1162+ u0 = u1 if u0 is None else u0
1163+ u1 = u0 if u1 is None else u1
1164+ if ufunc in [np .add , np .subtract ]:
1165+ if u0 != u1 :
1166+ raise ValueError
1167+ unit = u0
1168+ elif ufunc == np .multiply :
1169+ unit = f"{ u0 } *{ u1 } "
1170+ elif ufunc == np .divide :
1171+ unit = f"{ u0 } /({ u1 } )"
1172+ else :
1173+ raise NotImplementedError
1174+ out_arr = func (i0 .view (np .ndarray ), i1 .view (np .ndarray ), ** kwargs )
1175+ else :
1176+ raise NotImplementedError
1177+ if unit is None :
1178+ out_arr = np .array (out_arr )
1179+ else :
1180+ out_arr = QuantityND (out_arr , unit )
1181+ return out_arr
1182+
1183+ @property
1184+ def v (self ):
1185+ return self .view (np .ndarray )
1186+
1187+
1188+ def test_quantitynd ():
1189+ q = QuantityND ([1 , 2 ], "m" )
1190+ q0 , q1 = q [:]
1191+ assert np .all (q .v == np .asarray ([1 , 2 ]))
1192+ assert q .units == "m"
1193+ assert np .all ((q0 + q1 ).v == np .asarray ([3 ]))
1194+ assert (q0 * q1 ).units == "m*m"
1195+ assert (q1 / q0 ).units == "m/(m)"
1196+ with pytest .raises (ValueError ):
1197+ q0 + QuantityND (1 , "s" )
1198+
1199+
1200+ def test_imshow_quantitynd ():
1201+ # generate a dummy ndarray subclass
1202+ arr = QuantityND (np .ones ((2 , 2 )), "m" )
1203+ fig , ax = plt .subplots ()
1204+ ax .imshow (arr )
1205+ # executing the draw should not raise an exception
1206+ fig .canvas .draw ()
0 commit comments