@@ -462,6 +462,20 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
462
462
out = paddle .unsqueeze (out , a )
463
463
return out
464
464
465
+ _NP_2_PADDLE_DTYPE = {
466
+ "BOOL" : 'bool' ,
467
+ "UINT8" : 'uint8' ,
468
+ "INT8" : 'int8' ,
469
+ "INT16" : 'int16' ,
470
+ "INT32" : 'int32' ,
471
+ "INT64" : 'int64' ,
472
+ "FLOAT16" : 'float16' ,
473
+ "BFLOAT16" : 'bfloat16' ,
474
+ "FLOAT32" : 'float32' ,
475
+ "FLOAT64" : 'float64' ,
476
+ "COMPLEX128" : 'complex128' ,
477
+ "COMPLEX64" : 'complex64' ,
478
+ }
465
479
466
480
def prod (
467
481
x : array ,
@@ -476,7 +490,36 @@ def prod(
476
490
x = paddle .to_tensor (x )
477
491
ndim = x .ndim
478
492
479
- # below because it still needs to upcast.
493
+ # fix reducing on the zero dimension
494
+ if x .numel () == 0 :
495
+ if dtype is not None :
496
+ output_dtype = _NP_2_PADDLE_DTYPE [dtype .name ]
497
+ else :
498
+ if x .dtype == paddle .bool :
499
+ output_dtype = paddle .int64
500
+ else :
501
+ output_dtype = x .dtype
502
+
503
+ if axis is None :
504
+ return paddle .to_tensor (1 , dtype = output_dtype )
505
+
506
+ if keepdims :
507
+ output_shape = list (x .shape )
508
+ if isinstance (axis , int ):
509
+ axis = (axis ,)
510
+ for ax in axis :
511
+ output_shape [ax ] = 1
512
+ else :
513
+ output_shape = [dim for i , dim in enumerate (x .shape ) if i not in (axis if isinstance (axis , tuple ) else [axis ])]
514
+ if not output_shape :
515
+ return paddle .to_tensor (1 , dtype = output_dtype )
516
+
517
+ return paddle .ones (output_shape , dtype = output_dtype )
518
+
519
+
520
+ if dtype is not None :
521
+ dtype = _NP_2_PADDLE_DTYPE [dtype .name ]
522
+
480
523
if axis == ():
481
524
if dtype is None :
482
525
# We can't upcast uint8 according to the spec because there is no
@@ -492,13 +535,17 @@ def prod(
492
535
return _reduce_multiple_axes (
493
536
paddle .prod , x , axis , keepdim = keepdims , dtype = dtype , ** kwargs
494
537
)
538
+
539
+
495
540
if axis is None :
496
541
# paddle doesn't support keepdims with axis=None
542
+ if dtype is None and x .dtype == paddle .int32 :
543
+ dtype = 'int64'
497
544
res = paddle .prod (x , dtype = dtype , ** kwargs )
498
545
res = _axis_none_keepdims (res , ndim , keepdims )
499
546
return res
500
-
501
- return paddle .prod (x , axis , dtype = dtype , keepdim = keepdims , ** kwargs )
547
+
548
+ return paddle .prod (x , axis = axis , keepdims = keepdims , dtype = dtype , ** kwargs )
502
549
503
550
504
551
def sum (
@@ -747,7 +794,17 @@ def roll(
747
794
def nonzero (x : array , / , ** kwargs ) -> Tuple [array , ...]:
748
795
if x .ndim == 0 :
749
796
raise ValueError ("nonzero() does not support zero-dimensional arrays" )
750
- return paddle .nonzero (x , as_tuple = True , ** kwargs )
797
+
798
+ if paddle .is_floating_point (x ) or paddle .is_complex (x ) :
799
+ # Use paddle.isclose() to determine which elements are
800
+ # "close enough" to zero.
801
+ zero_tensor = paddle .zeros (shape = x .shape ,dtype = x .dtype )
802
+ is_zero_mask = paddle .isclose (x , zero_tensor )
803
+ is_nonzero_mask = paddle .logical_not (is_zero_mask )
804
+ return paddle .nonzero (is_nonzero_mask , as_tuple = True , ** kwargs )
805
+
806
+ else :
807
+ return paddle .nonzero (x , as_tuple = True , ** kwargs )
751
808
752
809
753
810
def where (condition : array , x1 : array , x2 : array , / ) -> array :
@@ -832,7 +889,7 @@ def eye(
832
889
if n_cols is None :
833
890
n_cols = n_rows
834
891
z = paddle .zeros ([n_rows , n_cols ], dtype = dtype , ** kwargs ).to (device )
835
- if abs (k ) <= n_rows + n_cols :
892
+ if n_rows > 0 and n_cols > 0 and abs (k ) <= n_rows + n_cols :
836
893
z .diagonal (k ).fill_ (1 )
837
894
return z
838
895
@@ -867,7 +924,11 @@ def full(
867
924
) -> array :
868
925
if isinstance (shape , int ):
869
926
shape = (shape ,)
870
-
927
+ if dtype is None :
928
+ if isinstance (fill_value , (bool )):
929
+ dtype = "bool"
930
+ elif isinstance (fill_value , int ):
931
+ dtype = 'int64'
871
932
return paddle .full (shape , fill_value , dtype = dtype , ** kwargs ).to (device )
872
933
873
934
@@ -914,6 +975,8 @@ def triu(x: array, /, *, k: int = 0) -> array:
914
975
915
976
916
977
def expand_dims (x : array , / , * , axis : int = 0 ) -> array :
978
+ if axis < - x .ndim - 1 or axis > x .ndim :
979
+ raise IndexError (f"Axis { axis } is out of bounds for array of dimension { x .ndim } " )
917
980
return paddle .unsqueeze (x , axis )
918
981
919
982
@@ -973,6 +1036,22 @@ def unique_values(x: array) -> array:
973
1036
974
1037
def matmul (x1 : array , x2 : array , / , ** kwargs ) -> array :
975
1038
# paddle.matmul doesn't type promote (but differently from _fix_promotion)
1039
+ d1 = x1 .ndim
1040
+ d2 = x2 .ndim
1041
+
1042
+ if d1 == 0 or d2 == 0 :
1043
+ raise ValueError ("matmul does not support 0-D (scalar) inputs." )
1044
+
1045
+ k1 = x1 .shape [- 1 ]
1046
+
1047
+ if d2 == 1 :
1048
+ k2 = x2 .shape [0 ]
1049
+ else :
1050
+ k2 = x2 .shape [- 2 ]
1051
+
1052
+ if k1 != k2 :
1053
+ raise ValueError (f"Shapes { x1 .shape } and { x2 .shape } are not aligned for matmul: "
1054
+ f"{ k1 } (dim -1 of x1) != { k2 } (dim -2 of x2)" )
976
1055
x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
977
1056
return paddle .matmul (x1 , x2 , ** kwargs )
978
1057
@@ -988,7 +1067,36 @@ def meshgrid(*arrays: array, indexing: str = "xy") -> List[array]:
988
1067
989
1068
990
1069
def vecdot (x1 : array , x2 : array , / , * , axis : int = - 1 ) -> array :
991
- x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
1070
+ shape1 = x1 .shape
1071
+ shape2 = x2 .shape
1072
+ rank1 = len (shape1 )
1073
+ rank2 = len (shape2 )
1074
+ if rank1 == 0 or rank2 == 0 :
1075
+ raise ValueError (
1076
+ f"Vector dot product requires non-scalar inputs (rank > 0). "
1077
+ f"Got ranks { rank1 } and { rank2 } for shapes { shape1 } and { shape2 } ."
1078
+ )
1079
+ try :
1080
+ norm_axis1 = axis if axis >= 0 else rank1 + axis
1081
+ if not (0 <= norm_axis1 < rank1 ):
1082
+ raise IndexError # Axis out of bounds for x1
1083
+ norm_axis2 = axis if axis >= 0 else rank2 + axis
1084
+ if not (0 <= norm_axis2 < rank2 ):
1085
+ raise IndexError # Axis out of bounds for x2
1086
+ size1 = shape1 [norm_axis1 ]
1087
+ size2 = shape2 [norm_axis2 ]
1088
+ except IndexError :
1089
+ raise ValueError (
1090
+ f"Axis { axis } is out of bounds for input shapes { shape1 } (rank { rank1 } ) "
1091
+ f"and/or { shape2 } (rank { rank2 } )."
1092
+ )
1093
+
1094
+ if size1 != size2 :
1095
+ raise ValueError (
1096
+ f"Inputs must have the same dimension size along the reduction axis ({ axis } ). "
1097
+ f"Got shapes { shape1 } and { shape2 } , with sizes { size1 } and { size2 } "
1098
+ f"along the normalized axis { norm_axis1 } and { norm_axis2 } respectively."
1099
+ )
992
1100
return paddle .linalg .vecdot (x1 , x2 , axis = axis )
993
1101
994
1102
@@ -1063,21 +1171,39 @@ def is_complex(dtype):
1063
1171
1064
1172
1065
1173
def take (x : array , indices : array , / , * , axis : Optional [int ] = None , ** kwargs ) -> array :
1066
- if axis is None :
1174
+ _axis = axis
1175
+ if _axis is None :
1067
1176
if x .ndim != 1 :
1068
- raise ValueError ("axis must be specified when ndim > 1" )
1069
- axis = 0
1070
- return paddle .index_select (x , axis , indices , ** kwargs )
1177
+ raise ValueError ("axis must be specified when x.ndim > 1" )
1178
+ _axis = 0
1179
+ elif not isinstance (_axis , int ):
1180
+ raise TypeError (f"axis must be an integer, but received { type (_axis )} " )
1181
+
1182
+ if not (- x .ndim <= _axis < x .ndim ):
1183
+ raise IndexError (f"axis { _axis } is out of bounds for tensor of dimension { x .ndim } " )
1184
+
1185
+ if isinstance (indices , paddle .Tensor ):
1186
+ indices_tensor = indices
1187
+ elif isinstance (indices , int ):
1188
+ indices_tensor = paddle .to_tensor ([indices ], dtype = 'int64' )
1189
+ else :
1190
+ # Otherwise (e.g., list, tuple), convert directly
1191
+ indices_tensor = paddle .to_tensor (indices , dtype = 'int64' )
1192
+ # Ensure indices is a 1D tensor
1193
+ if indices_tensor .ndim == 0 :
1194
+ indices_tensor = indices_tensor .reshape ([1 ])
1195
+ elif indices_tensor .ndim > 1 :
1196
+ raise ValueError (f"indices must be a 1D tensor, but received a { indices_tensor .ndim } D tensor" )
1197
+
1198
+ return paddle .index_select (x , index = indices_tensor , axis = _axis )
1071
1199
1072
1200
1073
1201
def sign (x : array , / ) -> array :
1074
1202
# paddle sign() does not support complex numbers and does not propagate
1075
1203
# nans. See https://github.com/data-apis/array-api-compat/issues/136
1076
- if paddle .is_complex (x ):
1077
- out = x / paddle .abs (x )
1078
- # sign(0) = 0 but the above formula would give nan
1079
- out [x == 0 + 0j ] = 0 + 0j
1080
- return out
1204
+ if paddle .is_complex (x ) and x .ndim == 0 and x .item () == 0j :
1205
+ # Handle 0-D complex zero explicitly
1206
+ return paddle .zeros_like (x , dtype = x .dtype )
1081
1207
else :
1082
1208
out = paddle .sign (x )
1083
1209
if paddle .is_floating_point (x ):
0 commit comments