Skip to content

Commit d16f6d5

Browse files
author
Vahid Tavanashad
committed
address comments
1 parent 31c6355 commit d16f6d5

File tree

2 files changed

+33
-79
lines changed

2 files changed

+33
-79
lines changed

dpnp/tests/test_nanfunctions.py

Lines changed: 29 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ def test_usm_ndarray(self, axis, overwrite_input):
465465
a, axis=axis, overwrite_input=overwrite_input
466466
)
467467
result = dpnp.nanmedian(ia, axis=axis, overwrite_input=overwrite_input)
468+
assert_dtype_allclose(result, expected)
468469

469470
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
470471
@pytest.mark.parametrize(
@@ -497,82 +498,93 @@ def test_error(self):
497498
dpnp.nanmedian(a, axis=1, out=res)
498499

499500

500-
class TestNanProd:
501+
@pytest.mark.parametrize("func", ["nanprod", "nansum"])
502+
class TestNanProdSum:
501503
@pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2, (1, 2), (0, -2)])
502504
@pytest.mark.parametrize("keepdims", [False, True])
503505
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
504-
def test_basic(self, axis, keepdims, dtype):
506+
def test_basic(self, func, axis, keepdims, dtype):
505507
a = generate_random_numpy_array((2, 3, 4), dtype=dtype)
506508
a[:, :, 2] = numpy.nan
507509
ia = dpnp.array(a)
508510

509-
expected = numpy.nanprod(a, axis=axis, keepdims=keepdims)
510-
result = dpnp.nanprod(ia, axis=axis, keepdims=keepdims)
511+
expected = getattr(numpy, func)(a, axis=axis, keepdims=keepdims)
512+
result = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims)
511513

512514
assert result.shape == expected.shape
513515
assert_allclose(result, expected, rtol=1e-6)
514516

515517
@pytest.mark.usefixtures("suppress_complex_warning")
516518
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
517519
@pytest.mark.parametrize("in_dt", get_float_complex_dtypes())
518-
@pytest.mark.parametrize("dt", get_all_dtypes(no_bool=True, no_none=True))
519-
def test_dtype(self, in_dt, dt):
520+
@pytest.mark.parametrize("dt", get_all_dtypes())
521+
def test_dtype(self, func, in_dt, dt):
520522
a = generate_random_numpy_array((2, 3, 4), dtype=in_dt)
521523
a[:, :, 2] = numpy.nan
522524
ia = dpnp.array(a)
523525

524-
expected = numpy.nanprod(a, dtype=dt)
525-
result = dpnp.nanprod(ia, dtype=dt)
526+
expected = getattr(numpy, func)(a, dtype=dt)
527+
result = getattr(dpnp, func)(ia, dtype=dt)
526528
assert_dtype_allclose(result, expected)
527529

528530
@pytest.mark.usefixtures(
529531
"suppress_overflow_encountered_in_cast_numpy_warnings"
530532
)
531-
def test_out(self):
533+
def test_out(self, func):
532534
ia = dpnp.arange(1, 7).reshape((2, 3))
533535
ia = ia.astype(dpnp.default_float_type(ia.device))
534536
ia[:, 1] = dpnp.nan
535537
a = dpnp.asnumpy(ia)
536538

537539
# out is dpnp_array
538-
expected = numpy.nanprod(a, axis=0)
540+
expected = getattr(numpy, func)(a, axis=0)
539541
iout = dpnp.empty(expected.shape, dtype=expected.dtype)
540-
result = dpnp.nanprod(ia, axis=0, out=iout)
542+
result = getattr(dpnp, func)(ia, axis=0, out=iout)
541543
assert iout is result
542544
assert_allclose(result, expected)
543545

544546
# out is usm_ndarray
545547
dpt_out = dpt.empty(expected.shape, dtype=expected.dtype)
546-
result = dpnp.nanprod(ia, axis=0, out=dpt_out)
548+
result = getattr(dpnp, func)(ia, axis=0, out=dpt_out)
547549
assert dpt_out is result.get_array()
548550
assert_allclose(result, expected)
549551

550552
# out is a numpy array -> TypeError
551553
iout = numpy.empty_like(expected)
552554
with pytest.raises(TypeError):
553-
dpnp.nanprod(ia, axis=0, out=iout)
555+
getattr(dpnp, func)(ia, axis=0, out=iout)
554556

555557
# incorrect shape for out
556558
iout = dpnp.array(numpy.empty((2, 3)))
557559
with pytest.raises(ValueError):
558-
dpnp.nanprod(ia, axis=0, out=iout)
560+
getattr(dpnp, func)(ia, axis=0, out=iout)
559561

560562
@pytest.mark.usefixtures("suppress_complex_warning")
561563
@pytest.mark.parametrize("in_dt", get_float_complex_dtypes())
562564
@pytest.mark.parametrize("out_dt", get_all_dtypes(no_none=True))
563-
def test_out_dtype(self, in_dt, out_dt):
565+
def test_out_dtype(self, func, in_dt, out_dt):
564566
# if out_dt is unsigned, input cannot be signed otherwise overflow occurs
565567
low = 0 if dpnp.issubdtype(out_dt, dpnp.unsignedinteger) else -5
566568
a = generate_random_numpy_array((2, 3, 4), dtype=in_dt, low=low, high=5)
567569
a[:, :, 2] = numpy.nan
568570
out = numpy.zeros_like(a, shape=(2, 3), dtype=out_dt)
569571
ia, iout = dpnp.array(a), dpnp.array(out)
570572

571-
result = dpnp.nanprod(ia, out=iout, axis=2)
572-
expected = numpy.nanprod(a, out=out, axis=2)
573+
result = getattr(dpnp, func)(ia, out=iout, axis=2)
574+
expected = getattr(numpy, func)(a, out=out, axis=2)
573575
assert_allclose(result, expected, rtol=1e-06)
574576
assert result is iout
575577

578+
@pytest.mark.parametrize("stride", [-1, 2])
579+
def test_strided(self, func, stride):
580+
a = numpy.arange(20.0)
581+
a[::3] = numpy.nan
582+
ia = dpnp.array(a)
583+
584+
result = getattr(dpnp, func)(ia[::stride])
585+
expected = getattr(numpy, func)(a[::stride])
586+
assert_allclose(result, expected)
587+
576588

577589
@pytest.mark.parametrize("func", ["nanstd", "nanvar"])
578590
class TestNanStdVar:
@@ -754,61 +766,3 @@ def test_error(self, func):
754766
# ddof should be an integer or float
755767
with pytest.raises(TypeError):
756768
getattr(dpnp, func)(ia, ddof="1")
757-
758-
759-
class TestNanSum:
760-
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
761-
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
762-
@pytest.mark.parametrize("keepdims", [True, False])
763-
def test_basic(self, dtype, axis, keepdims):
764-
ia = dpnp.array([[dpnp.nan, 1, 2], [3, dpnp.nan, 0]], dtype=dtype)
765-
a = dpnp.asnumpy(ia)
766-
767-
expected = numpy.nansum(a, axis=axis, keepdims=keepdims)
768-
result = dpnp.nansum(ia, axis=axis, keepdims=keepdims)
769-
assert_allclose(result, expected)
770-
771-
@pytest.mark.parametrize("dtype", get_complex_dtypes())
772-
def test_complex(self, dtype):
773-
a = generate_random_numpy_array(10, dtype=dtype)
774-
a[::3] = numpy.nan
775-
ia = dpnp.array(a)
776-
777-
expected = numpy.nansum(a)
778-
result = dpnp.nansum(ia)
779-
assert_dtype_allclose(result, expected)
780-
781-
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
782-
@pytest.mark.parametrize("axis", [0, 1])
783-
def test_out(self, dtype, axis):
784-
ia = dpnp.array([[dpnp.nan, 1, 2], [3, dpnp.nan, 0]], dtype=dtype)
785-
a = dpnp.asnumpy(ia)
786-
787-
expected = numpy.nansum(a, axis=axis)
788-
out = dpnp.empty_like(dpnp.asarray(expected))
789-
result = dpnp.nansum(ia, axis=axis, out=out)
790-
assert out is result
791-
assert_dtype_allclose(result, expected)
792-
793-
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
794-
def test_dtype(self, dtype):
795-
ia = dpnp.array([[dpnp.nan, 1, 2], [3, dpnp.nan, 0]])
796-
a = dpnp.asnumpy(ia)
797-
798-
expected = numpy.nansum(a, dtype=dtype)
799-
result = dpnp.nansum(ia, dtype=dtype)
800-
assert_dtype_allclose(result, expected)
801-
802-
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
803-
def test_strided(self, dtype):
804-
ia = dpnp.arange(20, dtype=dtype)
805-
ia[::3] = dpnp.nan
806-
a = dpnp.asnumpy(ia)
807-
808-
result = dpnp.nansum(ia[::-1])
809-
expected = numpy.nansum(a[::-1])
810-
assert_allclose(result, expected)
811-
812-
result = dpnp.nansum(ia[::2])
813-
expected = numpy.nansum(a[::2])
814-
assert_allclose(result, expected)

dpnp/tests/test_statistics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ def test_true_rowvar(self):
602602
class TestMaxMin:
603603
@pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2, (1, 2), (0, -2)])
604604
@pytest.mark.parametrize("keepdims", [False, True])
605-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=None))
605+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
606606
def test_func(self, func, axis, keepdims, dtype):
607607
a = generate_random_numpy_array((4, 4, 6, 8), dtype=dtype)
608608
ia = dpnp.array(a)
@@ -781,7 +781,7 @@ def test_0d_array(self):
781781
@pytest.mark.parametrize("axis", [None, 0, (0, 1), (0, -2, -1)])
782782
@pytest.mark.parametrize("keepdims", [True, False])
783783
def test_nan(self, axis, keepdims):
784-
a = numpy.random.uniform(-5, 5, 24).reshape(2, 3, 4)
784+
a = generate_random_numpy_array((2, 3, 4))
785785
a[0, 0, 0] = a[-1, -1, -1] = numpy.nan
786786
ia = dpnp.array(a)
787787

@@ -793,7 +793,7 @@ def test_nan(self, axis, keepdims):
793793
@pytest.mark.parametrize("axis", [None, 0, -1, (0, -2, -1)])
794794
@pytest.mark.parametrize("keepdims", [True, False])
795795
def test_overwrite_input(self, axis, keepdims):
796-
a = numpy.random.uniform(-5, 5, 24).reshape(2, 3, 4)
796+
a = generate_random_numpy_array((2, 3, 4))
797797
ia = dpnp.array(a)
798798

799799
b = a.copy()
@@ -812,7 +812,7 @@ def test_overwrite_input(self, axis, keepdims):
812812
@pytest.mark.parametrize("axis", [None, 0, (-1,), [0, 1]])
813813
@pytest.mark.parametrize("overwrite_input", [True, False])
814814
def test_usm_ndarray(self, axis, overwrite_input):
815-
a = numpy.random.uniform(-5, 5, 24).reshape(2, 3, 4)
815+
a = generate_random_numpy_array((2, 3, 4))
816816
ia = dpt.asarray(a)
817817

818818
expected = numpy.median(a, axis=axis, overwrite_input=overwrite_input)

0 commit comments

Comments
 (0)