@@ -231,79 +231,77 @@ class MyNDArray(np.ndarray):
231231 assert_ (res .shape == ())
232232
233233
234- class TestNanFunctions_IntTypes :
235-
236- int_types = (np .int8 , np .int16 , np .int32 , np .int64 , np .uint8 ,
237- np .uint16 , np .uint32 , np .uint64 )
234+ @pytest .mark .parametrize (
235+ "dtype" ,
236+ np .typecodes ["AllInteger" ] + np .typecodes ["AllFloat" ] + "O" ,
237+ )
238+ class TestNanFunctions_NumberTypes :
238239
239240 mat = np .array ([127 , 39 , 93 , 87 , 46 ])
240-
241- def integer_arrays (self ):
242- for dtype in self .int_types :
243- yield self .mat .astype (dtype )
244-
245- def test_nanmin (self ):
246- tgt = np .min (self .mat )
247- for mat in self .integer_arrays ():
248- assert_equal (np .nanmin (mat ), tgt )
249-
250- def test_nanmax (self ):
251- tgt = np .max (self .mat )
252- for mat in self .integer_arrays ():
253- assert_equal (np .nanmax (mat ), tgt )
254-
255- def test_nanargmin (self ):
256- tgt = np .argmin (self .mat )
257- for mat in self .integer_arrays ():
258- assert_equal (np .nanargmin (mat ), tgt )
259-
260- def test_nanargmax (self ):
261- tgt = np .argmax (self .mat )
262- for mat in self .integer_arrays ():
263- assert_equal (np .nanargmax (mat ), tgt )
264-
265- def test_nansum (self ):
266- tgt = np .sum (self .mat )
267- for mat in self .integer_arrays ():
268- assert_equal (np .nansum (mat ), tgt )
269-
270- def test_nanprod (self ):
271- tgt = np .prod (self .mat )
272- for mat in self .integer_arrays ():
273- assert_equal (np .nanprod (mat ), tgt )
274-
275- def test_nancumsum (self ):
276- tgt = np .cumsum (self .mat )
277- for mat in self .integer_arrays ():
278- assert_equal (np .nancumsum (mat ), tgt )
279-
280- def test_nancumprod (self ):
281- tgt = np .cumprod (self .mat )
282- for mat in self .integer_arrays ():
283- assert_equal (np .nancumprod (mat ), tgt )
284-
285- def test_nanmean (self ):
286- tgt = np .mean (self .mat )
287- for mat in self .integer_arrays ():
288- assert_equal (np .nanmean (mat ), tgt )
289-
290- def test_nanvar (self ):
291- tgt = np .var (self .mat )
292- for mat in self .integer_arrays ():
293- assert_equal (np .nanvar (mat ), tgt )
294-
295- tgt = np .var (mat , ddof = 1 )
296- for mat in self .integer_arrays ():
297- assert_equal (np .nanvar (mat , ddof = 1 ), tgt )
298-
299- def test_nanstd (self ):
300- tgt = np .std (self .mat )
301- for mat in self .integer_arrays ():
302- assert_equal (np .nanstd (mat ), tgt )
303-
304- tgt = np .std (self .mat , ddof = 1 )
305- for mat in self .integer_arrays ():
306- assert_equal (np .nanstd (mat , ddof = 1 ), tgt )
241+ mat .setflags (write = False )
242+
243+ nanfuncs = {
244+ np .nanmin : np .min ,
245+ np .nanmax : np .max ,
246+ np .nanargmin : np .argmin ,
247+ np .nanargmax : np .argmax ,
248+ np .nansum : np .sum ,
249+ np .nanprod : np .prod ,
250+ np .nancumsum : np .cumsum ,
251+ np .nancumprod : np .cumprod ,
252+ np .nanmean : np .mean ,
253+ np .nanmedian : np .median ,
254+ np .nanvar : np .var ,
255+ np .nanstd : np .std ,
256+ }
257+ nanfunc_ids = [i .__name__ for i in nanfuncs ]
258+
259+ @pytest .mark .parametrize ("nanfunc,func" , nanfuncs .items (), ids = nanfunc_ids )
260+ def test_nanfunc (self , dtype , nanfunc , func ):
261+ if nanfunc is np .nanprod and dtype == "e" :
262+ pytest .xfail (reason = "overflow encountered in reduce" )
263+
264+ mat = self .mat .astype (dtype )
265+ tgt = func (mat )
266+ out = nanfunc (mat )
267+
268+ assert_almost_equal (out , tgt )
269+ if dtype == "O" :
270+ assert type (out ) is type (tgt )
271+ else :
272+ assert out .dtype == tgt .dtype
273+
274+ @pytest .mark .parametrize (
275+ "nanfunc,func" ,
276+ [(np .nanquantile , np .quantile ), (np .nanpercentile , np .percentile )],
277+ ids = ["nanquantile" , "nanpercentile" ],
278+ )
279+ def test_nanfunc_q (self , dtype , nanfunc , func ):
280+ mat = self .mat .astype (dtype )
281+ tgt = func (mat , q = 1 )
282+ out = nanfunc (mat , q = 1 )
283+
284+ assert_almost_equal (out , tgt )
285+ if dtype == "O" :
286+ assert type (out ) is type (tgt )
287+ else :
288+ assert out .dtype == tgt .dtype
289+
290+ @pytest .mark .parametrize (
291+ "nanfunc,func" ,
292+ [(np .nanvar , np .var ), (np .nanstd , np .std )],
293+ ids = ["nanvar" , "nanstd" ],
294+ )
295+ def test_nanfunc_ddof (self , dtype , nanfunc , func ):
296+ mat = self .mat .astype (dtype )
297+ tgt = func (mat , ddof = 1 )
298+ out = nanfunc (mat , ddof = 1 )
299+
300+ assert_almost_equal (out , tgt )
301+ if dtype == "O" :
302+ assert type (out ) is type (tgt )
303+ else :
304+ assert out .dtype == tgt .dtype
307305
308306
309307class SharedNanFunctionsTestsMixin :
0 commit comments