Skip to content

Commit 13e2782

Browse files
author
Hongyuhe
committed
update
1 parent add32c9 commit 13e2782

File tree

3 files changed

+137
-14
lines changed

3 files changed

+137
-14
lines changed

array_api_compat/paddle/_aliases.py

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -496,33 +496,56 @@ def prod(
496496
x = paddle.to_tensor(x)
497497
ndim = x.ndim
498498

499+
# fix reducing on the zero dimension
500+
if x.numel() == 0:
501+
if dtype is not None:
502+
output_dtype = _NP_2_PADDLE_DTYPE[dtype.name]
503+
else:
504+
if x.dtype == paddle.bool:
505+
output_dtype = paddle.int64
506+
else:
507+
output_dtype = x.dtype
508+
509+
if axis is None:
510+
return paddle.to_tensor(1, dtype=output_dtype)
511+
512+
if keepdims:
513+
output_shape = list(x.shape)
514+
if isinstance(axis, int):
515+
axis = (axis,)
516+
for ax in axis:
517+
output_shape[ax] = 1
518+
else:
519+
output_shape = [dim for i, dim in enumerate(x.shape) if i not in (axis if isinstance(axis, tuple) else [axis])]
520+
if not output_shape:
521+
return paddle.to_tensor(1, dtype=output_dtype)
522+
523+
return paddle.ones(output_shape, dtype=output_dtype)
524+
525+
499526
if dtype is not None:
500-
# import pdb
501-
# pdb.set_trace()
502-
dtype = _NP_2_PADDLE_DTYPE[dtype.name]
503-
# below because it still needs to upcast.
527+
dtype = _NP_2_PADDLE_DTYPE[dtype.name]
528+
504529
if axis == ():
505530
if dtype is None:
506-
# We can't upcast uint8 according to the spec because there is no
507-
# paddle.uint64, so at least upcast to int64 which is what sum does
508-
# when axis=None.
509531
if x.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.uint8]:
510532
return x.to(paddle.int64)
511533
return x.clone()
512534
return x.to(dtype)
513535

514-
# paddle.prod doesn't support multiple axes
515536
if isinstance(axis, tuple):
516537
return _reduce_multiple_axes(
517538
paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
518539
)
540+
519541
if axis is None:
520-
# paddle doesn't support keepdims with axis=None
542+
if dtype is None and x.dtype == paddle.int32:
543+
dtype = 'int64'
521544
res = paddle.prod(x, dtype=dtype, **kwargs)
522545
res = _axis_none_keepdims(res, ndim, keepdims)
523546
return res
524-
525-
return paddle.prod(x, axis, dtype=dtype, keepdim=keepdims, **kwargs)
547+
548+
return paddle.prod(x, axis=axis, keepdims=keepdims, dtype=dtype, **kwargs)
526549

527550

528551
def sum(
@@ -771,7 +794,17 @@ def roll(
771794
def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
772795
if x.ndim == 0:
773796
raise ValueError("nonzero() does not support zero-dimensional arrays")
774-
return paddle.nonzero(x, as_tuple=True, **kwargs)
797+
798+
if paddle.is_floating_point(x) or paddle.is_complex(x) :
799+
# Use paddle.isclose() to determine which elements are
800+
# "close enough" to zero.
801+
zero_tensor = paddle.zeros(shape=x.shape ,dtype=x.dtype)
802+
is_zero_mask = paddle.isclose(x, zero_tensor)
803+
is_nonzero_mask = paddle.logical_not(is_zero_mask)
804+
return paddle.nonzero(is_nonzero_mask, as_tuple=True, **kwargs)
805+
806+
else:
807+
return paddle.nonzero(x, as_tuple=True, **kwargs)
775808

776809

777810
def where(condition: array, x1: array, x2: array, /) -> array:
@@ -1003,6 +1036,22 @@ def unique_values(x: array) -> array:
10031036

10041037
def matmul(x1: array, x2: array, /, **kwargs) -> array:
10051038
# paddle.matmul doesn't type promote (but differently from _fix_promotion)
1039+
d1 = x1.ndim
1040+
d2 = x2.ndim
1041+
1042+
if d1 == 0 or d2 == 0:
1043+
raise ValueError("matmul does not support 0-D (scalar) inputs.")
1044+
1045+
k1 = x1.shape[-1]
1046+
1047+
if d2 == 1:
1048+
k2 = x2.shape[0]
1049+
else:
1050+
k2 = x2.shape[-2]
1051+
1052+
if k1 != k2:
1053+
raise ValueError(f"Shapes {x1.shape} and {x2.shape} are not aligned for matmul: "
1054+
f"{k1} (dim -1 of x1) != {k2} (dim -2 of x2)")
10061055
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
10071056
return paddle.matmul(x1, x2, **kwargs)
10081057

array_api_compat/paddle/linalg.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,24 @@ def matrix_norm(
125125
keepdims: bool = False,
126126
ord: Optional[Union[int, float, Literal["fro", "nuc"]]] = "fro",
127127
) -> array:
128-
return paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims)
129-
128+
res = paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims)
129+
if res.dtype == paddle.complex64 :
130+
res = paddle.cast(res, "float32")
131+
if res.dtype == paddle.complex128:
132+
res = paddle.cast(res, "float64")
133+
return res
130134

131135
def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
132136
if rtol is None:
133137
return paddle.linalg.pinv(x)
138+
139+
# change rtol shape
140+
if isinstance(rtol, (int, float)):
141+
rtol = paddle.to_tensor(rtol, dtype=x.dtype)
142+
143+
# broadcast rtol to [..., 1]
144+
if rtol.ndim > 0:
145+
rtol = rtol.unsqueeze(-1)
134146

135147
return paddle.linalg.pinv(x, rcond=rtol)
136148

@@ -157,6 +169,9 @@ def svd(x: array, full_matrices: Optional[bool]= None) -> array:
157169
return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices=True), ['U', 'S', 'Vh'])
158170
return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices), ['U', 'S', 'Vh'])
159171

172+
def svdvals(x: array) -> array:
173+
return paddle.linalg.svd(x)[1]
174+
160175
__all__ = linalg_all + [
161176
"outer",
162177
"matmul",
@@ -171,6 +186,7 @@ def svd(x: array, full_matrices: Optional[bool]= None) -> array:
171186
"slogdet",
172187
"eigh",
173188
"diagonal",
189+
"svdvals"
174190
]
175191

176192
_all_ignore = ["paddle_linalg", "sum"]

paddle-xfails.txt

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,61 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_ceil
106106
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
107107
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
108108
array_api_tests/test_searching_functions.py::test_where
109+
110+
array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)]
111+
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)]
112+
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)]
113+
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x1, x2)]
114+
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
115+
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
116+
array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
117+
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
118+
array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)]
119+
array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)]
120+
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
121+
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x1, x2)]
122+
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)]
123+
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x, s)]
124+
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x1, x2)]
125+
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)]
126+
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)]
127+
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)]
128+
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]
129+
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x1, x2)]
130+
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)]
131+
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x1, x2)]
132+
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)]
133+
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[pow(x1, x2)]
134+
array_api_tests/test_operators_and_elementwise_functions.py::test_hypot
135+
array_api_tests/test_operators_and_elementwise_functions.py::test_copysign
136+
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)]
137+
array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
138+
array_api_tests/test_linalg.py::test_outer
139+
array_api_tests/test_linalg.py::test_vecdot
140+
array_api_tests/test_operators_and_elementwise_functions.py::test_clip
141+
array_api_tests/test_manipulation_functions.py::test_stack
142+
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide
143+
144+
# do not pass
145+
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
146+
array_api_tests/test_signatures.py::test_func_signature[meshgrid]
147+
148+
array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)]
149+
array_api_tests/test_indexing_functions.py::test_take
150+
array_api_tests/test_linalg.py::test_linalg_vecdot
151+
array_api_tests/test_creation_functions.py::test_asarray_arrays
152+
153+
array_api_tests/test_linalg.py::test_qr
154+
155+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift
156+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift
157+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and
158+
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or
159+
160+
# test exceeds the deadline of 800ms
161+
array_api_tests/test_linalg.py::test_pinv
162+
array_api_tests/test_linalg.py::test_det
163+
164+
# only supports access to dimension 0 to 9, but received dimension is 10.
165+
array_api_tests/test_linalg.py::test_tensordot
166+
array_api_tests/test_linalg.py::test_linalg_tensordot

0 commit comments

Comments (0)