Skip to content

Commit e6cf011

Browse files
author
Hongyuhe
committed
update
1 parent 13e2782 commit e6cf011

File tree

4 files changed

+9
-13
lines changed

4 files changed

+9
-13
lines changed

array_api_compat/paddle/_aliases.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -319,13 +319,6 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
319319
}
320320
return can_cast_dict[from_][to]
321321

322-
def test_bitwise_or(x: array, y: array):
323-
if not paddle.is_tensor(x):
324-
x = paddle.to_tensor(x)
325-
if not paddle.is_tensor(y):
326-
y = paddle.to_tensor(y)
327-
return paddle.bitwise_or(x, y)
328-
329322
# Basic renames
330323
bitwise_invert = paddle.bitwise_not
331324
newaxis = None
@@ -339,7 +332,7 @@ def test_bitwise_or(x: array, y: array):
339332
atan2 = _two_arg(paddle.atan2)
340333
bitwise_and = _two_arg(paddle.bitwise_and)
341334
bitwise_left_shift = _two_arg(paddle.bitwise_left_shift)
342-
bitwise_or = _two_arg(test_bitwise_or)
335+
bitwise_or = _two_arg(paddle.bitwise_or)
343336
bitwise_right_shift = _two_arg(paddle.bitwise_right_shift)
344337
bitwise_xor = _two_arg(paddle.bitwise_xor)
345338
copysign = _two_arg(paddle.copysign)
@@ -527,6 +520,9 @@ def prod(
527520
dtype = _NP_2_PADDLE_DTYPE[dtype.name]
528521

529522
if axis == ():
523+
# We can't upcast uint8 according to the spec because there is no
524+
# paddle.uint64, so at least upcast to int64 which is what sum does
525+
# when axis=None.
530526
if dtype is None:
531527
if x.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.uint8]:
532528
return x.to(paddle.int64)
@@ -537,8 +533,10 @@ def prod(
537533
return _reduce_multiple_axes(
538534
paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
539535
)
540-
536+
537+
541538
if axis is None:
539+
# paddle.prod doesn't support multiple axes
542540
if dtype is None and x.dtype == paddle.int32:
543541
dtype = 'int64'
544542
res = paddle.prod(x, dtype=dtype, **kwargs)

array_api_compat/torch/_aliases.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -611,9 +611,6 @@ def triu(x: array, /, *, k: int = 0) -> array:
611611

612612
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
613613
def expand_dims(x: array, /, *, axis: int = 0) -> array:
614-
if axis == 2:
615-
import pdb
616-
pdb.set_trace()
617614
return torch.unsqueeze(x, axis)
618615

619616
def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:

paddle-xfails.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_s
156156
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift
157157
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and
158158
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or
159+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor
159160

160161
# test exceeds the deadline of 800ms
161162
array_api_tests/test_linalg.py::test_pinv

vendor_test/vendored/_compat

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
../../array_api_compat
1+
../../array_api_compat/

0 commit comments

Comments
 (0)