Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions dpnp/tests/third_party/cupy/sorting_tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,22 @@ def test_nanargmin_zero_size_axis1(self, xp, dtype):
a = testing.shaped_random((0, 1), xp, dtype)
return xp.nanargmin(a, axis=1)

@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_allclose()
def test_nanargmin_out_float_dtype(self, xp, dtype):
a = xp.array([[0.0]])
b = xp.empty((1), dtype="int64")
xp.nanargmin(a, axis=1, out=b)
return b

@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_array_equal()
def test_nanargmin_out_int_dtype(self, xp, dtype):
a = xp.array([1, 0])
b = xp.empty((), dtype="int64")
xp.nanargmin(a, out=b)
return b


class TestNanArgMax:

Expand Down Expand Up @@ -623,6 +639,22 @@ def test_nanargmax_zero_size_axis1(self, xp, dtype):
a = testing.shaped_random((0, 1), xp, dtype)
return xp.nanargmax(a, axis=1)

@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_allclose()
def test_nanargmax_out_float_dtype(self, xp, dtype):
a = xp.array([[0.0]])
b = xp.empty((1), dtype="int64")
xp.nanargmax(a, axis=1, out=b)
return b

@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_array_equal()
def test_nanargmax_out_int_dtype(self, xp, dtype):
a = xp.array([0, 1])
b = xp.empty((), dtype="int64")
xp.nanargmax(a, out=b)
return b


@testing.parameterize(
*testing.product(
Expand Down Expand Up @@ -771,7 +803,7 @@ def test_invalid_sorter(self):

def test_nonint_sorter(self):
for xp in (numpy, cupy):
x = testing.shaped_arange((12,), xp, xp.float32)
x = testing.shaped_arange((12,), xp, xp.float64)
bins = xp.array([10, 4, 2, 1, 8])
sorter = xp.array([], dtype=xp.float32)
with pytest.raises((TypeError, ValueError)):
Expand Down Expand Up @@ -865,7 +897,7 @@ def test_invalid_sorter(self):

def test_nonint_sorter(self):
for xp in (numpy, cupy):
x = testing.shaped_arange((12,), xp, xp.float32)
x = testing.shaped_arange((12,), xp, xp.float64)
bins = xp.array([10, 4, 2, 1, 8])
sorter = xp.array([], dtype=xp.float32)
with pytest.raises((TypeError, ValueError)):
Expand Down
Loading