@@ -465,6 +465,7 @@ def test_usm_ndarray(self, axis, overwrite_input):
465465 a , axis = axis , overwrite_input = overwrite_input
466466 )
467467 result = dpnp .nanmedian (ia , axis = axis , overwrite_input = overwrite_input )
468+ assert_dtype_allclose (result , expected )
468469
469470 @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
470471 @pytest .mark .parametrize (
@@ -497,82 +498,93 @@ def test_error(self):
497498 dpnp .nanmedian (a , axis = 1 , out = res )
498499
499500
500- class TestNanProd :
501+ @pytest .mark .parametrize ("func" , ["nanprod" , "nansum" ])
502+ class TestNanProdSum :
501503 @pytest .mark .parametrize ("axis" , [None , 0 , 1 , - 1 , 2 , - 2 , (1 , 2 ), (0 , - 2 )])
502504 @pytest .mark .parametrize ("keepdims" , [False , True ])
503505 @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
504- def test_basic (self , axis , keepdims , dtype ):
506+ def test_basic (self , func , axis , keepdims , dtype ):
505507 a = generate_random_numpy_array ((2 , 3 , 4 ), dtype = dtype )
506508 a [:, :, 2 ] = numpy .nan
507509 ia = dpnp .array (a )
508510
509- expected = numpy . nanprod (a , axis = axis , keepdims = keepdims )
510- result = dpnp . nanprod (ia , axis = axis , keepdims = keepdims )
511+ expected = getattr ( numpy , func ) (a , axis = axis , keepdims = keepdims )
512+ result = getattr ( dpnp , func ) (ia , axis = axis , keepdims = keepdims )
511513
512514 assert result .shape == expected .shape
513515 assert_allclose (result , expected , rtol = 1e-6 )
514516
515517 @pytest .mark .usefixtures ("suppress_complex_warning" )
516518 @pytest .mark .usefixtures ("suppress_invalid_numpy_warnings" )
517519 @pytest .mark .parametrize ("in_dt" , get_float_complex_dtypes ())
518- @pytest .mark .parametrize ("dt" , get_all_dtypes (no_bool = True , no_none = True ))
519- def test_dtype (self , in_dt , dt ):
520+ @pytest .mark .parametrize ("dt" , get_all_dtypes ())
521+ def test_dtype (self , func , in_dt , dt ):
520522 a = generate_random_numpy_array ((2 , 3 , 4 ), dtype = in_dt )
521523 a [:, :, 2 ] = numpy .nan
522524 ia = dpnp .array (a )
523525
524- expected = numpy . nanprod (a , dtype = dt )
525- result = dpnp . nanprod (ia , dtype = dt )
526+ expected = getattr ( numpy , func ) (a , dtype = dt )
527+ result = getattr ( dpnp , func ) (ia , dtype = dt )
526528 assert_dtype_allclose (result , expected )
527529
528530 @pytest .mark .usefixtures (
529531 "suppress_overflow_encountered_in_cast_numpy_warnings"
530532 )
531- def test_out (self ):
533+ def test_out (self , func ):
532534 ia = dpnp .arange (1 , 7 ).reshape ((2 , 3 ))
533535 ia = ia .astype (dpnp .default_float_type (ia .device ))
534536 ia [:, 1 ] = dpnp .nan
535537 a = dpnp .asnumpy (ia )
536538
537539 # out is dpnp_array
538- expected = numpy . nanprod (a , axis = 0 )
540+ expected = getattr ( numpy , func ) (a , axis = 0 )
539541 iout = dpnp .empty (expected .shape , dtype = expected .dtype )
540- result = dpnp . nanprod (ia , axis = 0 , out = iout )
542+ result = getattr ( dpnp , func ) (ia , axis = 0 , out = iout )
541543 assert iout is result
542544 assert_allclose (result , expected )
543545
544546 # out is usm_ndarray
545547 dpt_out = dpt .empty (expected .shape , dtype = expected .dtype )
546- result = dpnp . nanprod (ia , axis = 0 , out = dpt_out )
548+ result = getattr ( dpnp , func ) (ia , axis = 0 , out = dpt_out )
547549 assert dpt_out is result .get_array ()
548550 assert_allclose (result , expected )
549551
550552 # out is a numpy array -> TypeError
551553 iout = numpy .empty_like (expected )
552554 with pytest .raises (TypeError ):
553- dpnp . nanprod (ia , axis = 0 , out = iout )
555+ getattr ( dpnp , func ) (ia , axis = 0 , out = iout )
554556
555557 # incorrect shape for out
556558 iout = dpnp .array (numpy .empty ((2 , 3 )))
557559 with pytest .raises (ValueError ):
558- dpnp . nanprod (ia , axis = 0 , out = iout )
560+ getattr ( dpnp , func ) (ia , axis = 0 , out = iout )
559561
560562 @pytest .mark .usefixtures ("suppress_complex_warning" )
561563 @pytest .mark .parametrize ("in_dt" , get_float_complex_dtypes ())
562564 @pytest .mark .parametrize ("out_dt" , get_all_dtypes (no_none = True ))
563- def test_out_dtype (self , in_dt , out_dt ):
565+ def test_out_dtype (self , func , in_dt , out_dt ):
564566 # if out_dt is unsigned, input cannot be signed otherwise overflow occurs
565567 low = 0 if dpnp .issubdtype (out_dt , dpnp .unsignedinteger ) else - 5
566568 a = generate_random_numpy_array ((2 , 3 , 4 ), dtype = in_dt , low = low , high = 5 )
567569 a [:, :, 2 ] = numpy .nan
568570 out = numpy .zeros_like (a , shape = (2 , 3 ), dtype = out_dt )
569571 ia , iout = dpnp .array (a ), dpnp .array (out )
570572
571- result = dpnp . nanprod (ia , out = iout , axis = 2 )
572- expected = numpy . nanprod (a , out = out , axis = 2 )
573+ result = getattr ( dpnp , func ) (ia , out = iout , axis = 2 )
574+ expected = getattr ( numpy , func ) (a , out = out , axis = 2 )
573575 assert_allclose (result , expected , rtol = 1e-06 )
574576 assert result is iout
575577
578+ @pytest .mark .parametrize ("stride" , [- 1 , 2 ])
579+ def test_strided (self , func , stride ):
580+ a = numpy .arange (20.0 )
581+ a [::3 ] = numpy .nan
582+ ia = dpnp .array (a )
583+
584+ result = getattr (dpnp , func )(ia [::stride ])
585+ expected = getattr (numpy , func )(a [::stride ])
586+ assert_allclose (result , expected )
587+
576588
577589@pytest .mark .parametrize ("func" , ["nanstd" , "nanvar" ])
578590class TestNanStdVar :
@@ -754,61 +766,3 @@ def test_error(self, func):
754766 # ddof should be an integer or float
755767 with pytest .raises (TypeError ):
756768 getattr (dpnp , func )(ia , ddof = "1" )
757-
758-
759- class TestNanSum :
760- @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
761- @pytest .mark .parametrize ("axis" , [None , 0 , 1 , (0 , 1 )])
762- @pytest .mark .parametrize ("keepdims" , [True , False ])
763- def test_basic (self , dtype , axis , keepdims ):
764- ia = dpnp .array ([[dpnp .nan , 1 , 2 ], [3 , dpnp .nan , 0 ]], dtype = dtype )
765- a = dpnp .asnumpy (ia )
766-
767- expected = numpy .nansum (a , axis = axis , keepdims = keepdims )
768- result = dpnp .nansum (ia , axis = axis , keepdims = keepdims )
769- assert_allclose (result , expected )
770-
771- @pytest .mark .parametrize ("dtype" , get_complex_dtypes ())
772- def test_complex (self , dtype ):
773- a = generate_random_numpy_array (10 , dtype = dtype )
774- a [::3 ] = numpy .nan
775- ia = dpnp .array (a )
776-
777- expected = numpy .nansum (a )
778- result = dpnp .nansum (ia )
779- assert_dtype_allclose (result , expected )
780-
781- @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
782- @pytest .mark .parametrize ("axis" , [0 , 1 ])
783- def test_out (self , dtype , axis ):
784- ia = dpnp .array ([[dpnp .nan , 1 , 2 ], [3 , dpnp .nan , 0 ]], dtype = dtype )
785- a = dpnp .asnumpy (ia )
786-
787- expected = numpy .nansum (a , axis = axis )
788- out = dpnp .empty_like (dpnp .asarray (expected ))
789- result = dpnp .nansum (ia , axis = axis , out = out )
790- assert out is result
791- assert_dtype_allclose (result , expected )
792-
793- @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
794- def test_dtype (self , dtype ):
795- ia = dpnp .array ([[dpnp .nan , 1 , 2 ], [3 , dpnp .nan , 0 ]])
796- a = dpnp .asnumpy (ia )
797-
798- expected = numpy .nansum (a , dtype = dtype )
799- result = dpnp .nansum (ia , dtype = dtype )
800- assert_dtype_allclose (result , expected )
801-
802- @pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
803- def test_strided (self , dtype ):
804- ia = dpnp .arange (20 , dtype = dtype )
805- ia [::3 ] = dpnp .nan
806- a = dpnp .asnumpy (ia )
807-
808- result = dpnp .nansum (ia [::- 1 ])
809- expected = numpy .nansum (a [::- 1 ])
810- assert_allclose (result , expected )
811-
812- result = dpnp .nansum (ia [::2 ])
813- expected = numpy .nansum (a [::2 ])
814- assert_allclose (result , expected )
0 commit comments