Skip to content

Commit ce0dadb

Browse files
Use asnumpy().all() in tests to avoid additional kernel launches
1 parent 239e5b1 commit ce0dadb

File tree

3 files changed

+14
-14
lines changed

3 files changed

+14
-14
lines changed

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,7 +1435,7 @@ def test_nonzero_f_contig():
14351435

14361436
for exp, res in zip(expected_res, result):
14371437
assert_array_equal(dpt.asnumpy(res), exp)
1438-
assert dpt.all(mask[result])
1438+
assert dpt.asnumpy(mask[result]).all()
14391439

14401440

14411441
def test_nonzero_compacting():
@@ -1454,7 +1454,7 @@ def test_nonzero_compacting():
14541454

14551455
for exp, res in zip(expected_res, result):
14561456
assert_array_equal(dpt.asnumpy(res), exp)
1457-
assert dpt.all(mask_view[result])
1457+
assert dpt.asnumpy(mask_view[result]).all()
14581458

14591459

14601460
def test_assign_scalar():

dpctl/tests/test_usm_ndarray_operators.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,14 @@ def test_mat_ops(namespace):
129129
@pytest.mark.parametrize("namespace", [dpt, Dummy()])
130130
def test_comp_ops(namespace):
131131
try:
132-
X = dpt.ones(1, dtype="u8")
132+
X = dpt.asarray(1, dtype="u8")
133133
except dpctl.SyclDeviceCreationError:
134134
pytest.skip("No SYCL devices available")
135135
X._set_namespace(namespace)
136136
assert X.__array_namespace__() is namespace
137-
assert dpt.all(X.__gt__(-1))
138-
assert dpt.all(X.__ge__(-1))
139-
assert not dpt.all(X.__lt__(-1))
140-
assert not dpt.all(X.__le__(-1))
141-
assert not dpt.all(X.__eq__(-1))
142-
assert dpt.all(X.__ne__(-1))
137+
assert X.__gt__(-1)
138+
assert X.__ge__(-1)
139+
assert not X.__lt__(-1)
140+
assert not X.__le__(-1)
141+
assert not X.__eq__(-1)
142+
assert X.__ne__(-1)

dpctl/tests/test_usm_ndarray_sorting.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -354,11 +354,11 @@ def test_radix_sort_size_1_axis():
354354

355355
x3 = dpt.reshape(dpt.arange(10, dtype="i1"), (10, 1))
356356
r3 = dpt.sort(x3, kind="radixsort")
357-
assert dpt.all(r3 == x3)
357+
assert dpt.asnumpy(r3 == x3).all()
358358

359359
x4 = dpt.reshape(dpt.arange(10, dtype="i1"), (1, 10))
360360
r4 = dpt.sort(x4, axis=0, kind="radixsort")
361-
assert dpt.all(r4 == x4)
361+
assert dpt.asnumpy(r4 == x4).all()
362362

363363

364364
def test_radix_argsort_size_1_axis():
@@ -370,12 +370,12 @@ def test_radix_argsort_size_1_axis():
370370

371371
x2 = dpt.ones([1], dtype="i1")
372372
r2 = dpt.argsort(x2, kind="radixsort")
373-
assert dpt.all(r2 == 0)
373+
assert dpt.asnumpy(r2 == 0).all()
374374

375375
x3 = dpt.reshape(dpt.arange(10, dtype="i1"), (10, 1))
376376
r3 = dpt.argsort(x3, kind="radixsort")
377-
assert dpt.all(r3 == 0)
377+
assert dpt.asnumpy(r3 == 0).all()
378378

379379
x4 = dpt.reshape(dpt.arange(10, dtype="i1"), (1, 10))
380380
r4 = dpt.argsort(x4, axis=0, kind="radixsort")
381-
assert dpt.all(r4 == 0)
381+
assert dpt.asnumpy(r4 == 0).all()

0 commit comments

Comments
 (0)