Skip to content

Commit 67aa9ef

Browse files
author
Hongyuhe
committed
updat
1 parent e6cf011 commit 67aa9ef

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

array_api_compat/paddle/_aliases.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
319319
}
320320
return can_cast_dict[from_][to]
321321

322+
322323
# Basic renames
323324
bitwise_invert = paddle.bitwise_not
324325
newaxis = None
@@ -520,23 +521,24 @@ def prod(
520521
dtype = _NP_2_PADDLE_DTYPE[dtype.name]
521522

522523
if axis == ():
523-
# We can't upcast uint8 according to the spec because there is no
524-
# paddle.uint64, so at least upcast to int64 which is what sum does
525-
# when axis=None.
526524
if dtype is None:
525+
# We can't upcast uint8 according to the spec because there is no
526+
# paddle.uint64, so at least upcast to int64 which is what sum does
527+
# when axis=None.
527528
if x.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.uint8]:
528529
return x.to(paddle.int64)
529530
return x.clone()
530531
return x.to(dtype)
531532

533+
# paddle.prod doesn't support multiple axes
532534
if isinstance(axis, tuple):
533535
return _reduce_multiple_axes(
534536
paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
535537
)
536538

537539

538540
if axis is None:
539-
# paddle.prod doesn't support multiple axes
541+
# paddle doesn't support keepdims with axis=None
540542
if dtype is None and x.dtype == paddle.int32:
541543
dtype = 'int64'
542544
res = paddle.prod(x, dtype=dtype, **kwargs)
@@ -1283,6 +1285,7 @@ def floor(x: array, /) -> array:
12831285
def ceil(x: array, /) -> array:
12841286
return paddle.ceil(x).to(x.dtype)
12851287

1288+
12861289
def clip(
12871290
x: array,
12881291
/,
@@ -1357,6 +1360,7 @@ def cumulative_sum(
13571360
"axis must be specified in cumulative_sum for more than one dimension"
13581361
)
13591362
axis = 0
1363+
13601364
res = paddle.cumsum(x, axis=axis, dtype=dtype)
13611365

13621366
# np.cumsum does not support include_initial
@@ -1387,6 +1391,7 @@ def searchsorted(
13871391
right=(side == "right"),
13881392
)
13891393

1394+
13901395
__all__ = [
13911396
"__array_namespace_info__",
13921397
"result_type",

0 commit comments

Comments
 (0)