Skip to content

Commit ba2e4c8

Browse files
author
Vahid Tavanashad
committed
use assert_dtype_allclose
1 parent 33fda9d commit ba2e4c8

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tests/test_statistics.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def test_basic(self, dtype, size):
313313
expected = numpy.median(a)
314314
result = dpnp.median(ia)
315315

316-
assert_allclose(result, expected)
316+
assert_dtype_allclose(result, expected)
317317

318318
@pytest.mark.parametrize("axis", [None, 0, (-1,), [0, 1], (0, -2, -1)])
319319
@pytest.mark.parametrize("keepdims", [True, False])
@@ -324,7 +324,7 @@ def test_axis(self, axis, keepdims):
324324
expected = numpy.median(a, axis=axis, keepdims=keepdims)
325325
result = dpnp.median(ia, axis=axis, keepdims=keepdims)
326326

327-
assert_allclose(result, expected)
327+
assert_dtype_allclose(result, expected)
328328

329329
@pytest.mark.usefixtures(
330330
"suppress_invalid_numpy_warnings",
@@ -338,7 +338,7 @@ def test_empty(self, axis, shape):
338338

339339
result = dpnp.median(ia, axis=axis)
340340
expected = numpy.median(a, axis=axis)
341-
assert_allclose(expected, result)
341+
assert_dtype_allclose(result, expected)
342342

343343
@pytest.mark.parametrize("dtype", get_all_dtypes())
344344
@pytest.mark.parametrize(
@@ -361,7 +361,7 @@ def test_0d_array(self):
361361

362362
result = dpnp.median(ia)
363363
expected = numpy.median(a)
364-
assert_allclose(expected, result)
364+
assert_dtype_allclose(result, expected)
365365

366366
@pytest.mark.parametrize("axis", [None, 0, (0, 1), (0, -2, -1)])
367367
@pytest.mark.parametrize("keepdims", [True, False])
@@ -373,7 +373,7 @@ def test_nan(self, axis, keepdims):
373373
expected = numpy.median(a, axis=axis, keepdims=keepdims)
374374
result = dpnp.median(ia, axis=axis, keepdims=keepdims)
375375

376-
assert_allclose(result, expected)
376+
assert_dtype_allclose(result, expected)
377377

378378
@pytest.mark.parametrize("axis", [None, 0, -1, (0, -2, -1)])
379379
@pytest.mark.parametrize("keepdims", [True, False])
@@ -392,7 +392,7 @@ def test_overwrite_input(self, axis, keepdims):
392392
assert not numpy.all(a == b)
393393
assert not dpnp.all(ia == ib)
394394

395-
assert_allclose(result, expected)
395+
assert_dtype_allclose(result, expected)
396396

397397
@pytest.mark.parametrize("axis", [None, 0, (-1,), [0, 1]])
398398
@pytest.mark.parametrize("overwrite_input", [True, False])
@@ -403,7 +403,7 @@ def test_usm_ndarray(self, axis, overwrite_input):
403403
expected = numpy.median(a, axis=axis, overwrite_input=overwrite_input)
404404
result = dpnp.median(ia, axis=axis, overwrite_input=overwrite_input)
405405

406-
assert_allclose(result, expected)
406+
assert_dtype_allclose(result, expected)
407407

408408

409409
class TestVar:

0 commit comments

Comments
 (0)