Skip to content

Commit b5c9438

Browse files
Add out keyword to docstring and update tests
1 parent 62b015a commit b5c9438

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

dpctl/tensor/_elementwise_funcs.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@
481481

482482
# B20: ==== NOT_EQUAL (x1, x2)
483483
_not_equal_docstring_ = """
484-
not_equal(x1, x2, order='K')
484+
not_equal(x1, x2, out=None, order='K')
485485
486486
Calculates inequality test results for each element `x1_i` of the
487487
input array `x1` with the respective element `x2_i` of the input array `x2`.
@@ -491,6 +491,12 @@
491491
First input array, expected to have numeric data type.
492492
x2 (usm_ndarray):
493493
Second input array, also expected to have numeric data type.
494+
out ({None, usm_ndarray}, optional):
495+
Output array to populate.
496+
Array have the correct shape and the expected data type.
497+
order ("C","F","A","K", optional):
498+
Memory layout of the newly output array, if parameter `out` is `None`.
499+
Default: "K".
494500
Returns:
495501
usm_narray:
496502
an array containing the result of element-wise inequality comparison.

dpctl/tests/elementwise/test_equal.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ def test_equal_broadcasting():
127127
r2 = dpt.equal(v, m)
128128
assert (dpt.asnumpy(r2) == expected).all()
129129

130+
r3 = dpt.empty_like(m, dtype="?")
131+
dpt.equal(m, v, r3)
132+
assert (dpt.asnumpy(r3) == expected).all()
133+
130134

131135
@pytest.mark.parametrize("arr_dt", _all_dtypes)
132136
def test_equal_python_scalar(arr_dt):

dpctl/tests/elementwise/test_not_equal.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ def test_not_equal_broadcasting():
127127
r2 = dpt.not_equal(v, m)
128128
assert (dpt.asnumpy(r2) == expected).all()
129129

130+
r3 = dpt.empty_like(m, dtype="?")
131+
dpt.not_equal(m, v, r3)
132+
assert (dpt.asnumpy(r3) == expected).all()
133+
130134

131135
@pytest.mark.parametrize("arr_dt", _all_dtypes)
132136
def test_not_equal_python_scalar(arr_dt):

0 commit comments

Comments
 (0)