@@ -319,6 +319,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
319
319
}
320
320
return can_cast_dict [from_ ][to ]
321
321
322
+
322
323
# Basic renames
323
324
bitwise_invert = paddle .bitwise_not
324
325
newaxis = None
@@ -520,23 +521,24 @@ def prod(
520
521
dtype = _NP_2_PADDLE_DTYPE [dtype .name ]
521
522
522
523
if axis == ():
523
- # We can't upcast uint8 according to the spec because there is no
524
- # paddle.uint64, so at least upcast to int64 which is what sum does
525
- # when axis=None.
526
524
if dtype is None :
525
+ # We can't upcast uint8 according to the spec because there is no
526
+ # paddle.uint64, so at least upcast to int64 which is what sum does
527
+ # when axis=None.
527
528
if x .dtype in [paddle .int8 , paddle .int16 , paddle .int32 , paddle .uint8 ]:
528
529
return x .to (paddle .int64 )
529
530
return x .clone ()
530
531
return x .to (dtype )
531
532
533
+ # paddle.prod doesn't support multiple axes
532
534
if isinstance (axis , tuple ):
533
535
return _reduce_multiple_axes (
534
536
paddle .prod , x , axis , keepdim = keepdims , dtype = dtype , ** kwargs
535
537
)
536
538
537
539
538
540
if axis is None :
539
- # paddle.prod doesn't support multiple axes
541
+ # paddle doesn't support keepdims with axis=None
540
542
if dtype is None and x .dtype == paddle .int32 :
541
543
dtype = 'int64'
542
544
res = paddle .prod (x , dtype = dtype , ** kwargs )
@@ -1283,6 +1285,7 @@ def floor(x: array, /) -> array:
1283
1285
def ceil (x : array , / ) -> array :
1284
1286
return paddle .ceil (x ).to (x .dtype )
1285
1287
1288
+
1286
1289
def clip (
1287
1290
x : array ,
1288
1291
/ ,
@@ -1357,6 +1360,7 @@ def cumulative_sum(
1357
1360
"axis must be specified in cumulative_sum for more than one dimension"
1358
1361
)
1359
1362
axis = 0
1363
+
1360
1364
res = paddle .cumsum (x , axis = axis , dtype = dtype )
1361
1365
1362
1366
# np.cumsum does not support include_initial
@@ -1387,6 +1391,7 @@ def searchsorted(
1387
1391
right = (side == "right" ),
1388
1392
)
1389
1393
1394
+
1390
1395
__all__ = [
1391
1396
"__array_namespace_info__" ,
1392
1397
"result_type" ,
0 commit comments