
Commit 41a93be

Merge pull request #2 from aquagull/support_paddle: Support paddle
2 parents 912fe3e + 67aa9ef

File tree: 3 files changed, +241 -23 lines

array_api_compat/paddle/_aliases.py (142 additions, 16 deletions)
@@ -462,6 +462,20 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
         out = paddle.unsqueeze(out, a)
     return out

+_NP_2_PADDLE_DTYPE = {
+    "BOOL": 'bool',
+    "UINT8": 'uint8',
+    "INT8": 'int8',
+    "INT16": 'int16',
+    "INT32": 'int32',
+    "INT64": 'int64',
+    "FLOAT16": 'float16',
+    "BFLOAT16": 'bfloat16',
+    "FLOAT32": 'float32',
+    "FLOAT64": 'float64',
+    "COMPLEX128": 'complex128',
+    "COMPLEX64": 'complex64',
+}

 def prod(
     x: array,
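
The mapping translates the upper-case .name of a dtype object into the lowercase dtype string that paddle kernels accept. A minimal sketch of the lookup, with a hypothetical Dtype enum standing in for the real dtype objects whose .name prod() reads:

from enum import Enum

class Dtype(Enum):
    FLOAT32 = 0
    INT64 = 1

_NP_2_PADDLE_DTYPE = {"FLOAT32": "float32", "INT64": "int64"}

print(_NP_2_PADDLE_DTYPE[Dtype.FLOAT32.name])  # -> float32
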
@@ -476,7 +490,36 @@ def prod(
     x = paddle.to_tensor(x)
     ndim = x.ndim

-    # below because it still needs to upcast.
+    # fix reducing on the zero dimension
+    if x.numel() == 0:
+        if dtype is not None:
+            output_dtype = _NP_2_PADDLE_DTYPE[dtype.name]
+        else:
+            if x.dtype == paddle.bool:
+                output_dtype = paddle.int64
+            else:
+                output_dtype = x.dtype
+
+        if axis is None:
+            return paddle.to_tensor(1, dtype=output_dtype)
+
+        if keepdims:
+            output_shape = list(x.shape)
+            if isinstance(axis, int):
+                axis = (axis,)
+            for ax in axis:
+                output_shape[ax] = 1
+        else:
+            output_shape = [dim for i, dim in enumerate(x.shape) if i not in (axis if isinstance(axis, tuple) else [axis])]
+            if not output_shape:
+                return paddle.to_tensor(1, dtype=output_dtype)
+
+        return paddle.ones(output_shape, dtype=output_dtype)
+
+
+    if dtype is not None:
+        dtype = _NP_2_PADDLE_DTYPE[dtype.name]
+
     if axis == ():
         if dtype is None:
             # We can't upcast uint8 according to the spec because there is no
@@ -492,13 +535,17 @@ def prod(
         return _reduce_multiple_axes(
             paddle.prod, x, axis, keepdim=keepdims, dtype=dtype, **kwargs
         )
+
+
     if axis is None:
         # paddle doesn't support keepdims with axis=None
+        if dtype is None and x.dtype == paddle.int32:
+            dtype = 'int64'
         res = paddle.prod(x, dtype=dtype, **kwargs)
         res = _axis_none_keepdims(res, ndim, keepdims)
         return res
-
-    return paddle.prod(x, axis, dtype=dtype, keepdim=keepdims, **kwargs)
+
+    return paddle.prod(x, axis=axis, keepdims=keepdims, dtype=dtype, **kwargs)


 def sum(
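
Taken together, the prod() changes make the empty-input case return the multiplicative identity with a sensible dtype. A quick sanity check, assuming the wrapped namespace is importable as array_api_compat.paddle and exposes ones() and prod():

import array_api_compat.paddle as xp

x = xp.ones((0, 3))
print(xp.prod(x, axis=0))  # expected: a tensor of ones with shape [3]
print(xp.prod(x))          # expected: a scalar tensor equal to 1
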
@@ -747,7 +794,17 @@ def roll(
 def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
     if x.ndim == 0:
         raise ValueError("nonzero() does not support zero-dimensional arrays")
-    return paddle.nonzero(x, as_tuple=True, **kwargs)
+
+    if paddle.is_floating_point(x) or paddle.is_complex(x):
+        # Use paddle.isclose() to determine which elements are
+        # "close enough" to zero.
+        zero_tensor = paddle.zeros(shape=x.shape, dtype=x.dtype)
+        is_zero_mask = paddle.isclose(x, zero_tensor)
+        is_nonzero_mask = paddle.logical_not(is_zero_mask)
+        return paddle.nonzero(is_nonzero_mask, as_tuple=True, **kwargs)
+
+    else:
+        return paddle.nonzero(x, as_tuple=True, **kwargs)


 def where(condition: array, x1: array, x2: array, /) -> array:
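
The rewritten nonzero() treats floating-point and complex values that fall within paddle.isclose's default tolerance of zero as zeros. A small sketch of just the masking step, assuming default rtol/atol:

import paddle

x = paddle.to_tensor([0.0, 1e-30, 1.0])
zero = paddle.zeros(shape=x.shape, dtype=x.dtype)
mask = paddle.logical_not(paddle.isclose(x, zero))
print(paddle.nonzero(mask, as_tuple=True)[0])  # only index 2 survives

Note that this is looser than an exact x != 0 test: the denormal 1e-30 above is filtered out along with the true zero.
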
@@ -832,7 +889,7 @@ def eye(
     if n_cols is None:
         n_cols = n_rows
     z = paddle.zeros([n_rows, n_cols], dtype=dtype, **kwargs).to(device)
-    if abs(k) <= n_rows + n_cols:
+    if n_rows > 0 and n_cols > 0 and abs(k) <= n_rows + n_cols:
        z.diagonal(k).fill_(1)
     return z

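The extra n_rows > 0 and n_cols > 0 guard exists because an empty matrix has no diagonal to fill. A sketch of the guarded path:

import paddle

n_rows, n_cols, k = 0, 3, 0
z = paddle.zeros([n_rows, n_cols], dtype="float32")
if n_rows > 0 and n_cols > 0 and abs(k) <= n_rows + n_cols:
    z.diagonal(k).fill_(1)
print(z.shape)  # [0, 3]; the in-place fill is skipped entirely
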
@@ -867,7 +924,11 @@ def full(
 ) -> array:
     if isinstance(shape, int):
         shape = (shape,)
-
+    if dtype is None:
+        if isinstance(fill_value, bool):
+            dtype = "bool"
+        elif isinstance(fill_value, int):
+            dtype = 'int64'
     return paddle.full(shape, fill_value, dtype=dtype, **kwargs).to(device)

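Because bool is a subclass of int in Python, the bool branch must run before the int branch, or True would be promoted to int64. A standalone sketch of the inference (infer_fill_dtype is a hypothetical name, not part of the diff):

def infer_fill_dtype(fill_value):
    if isinstance(fill_value, bool):  # checked first: bool subclasses int
        return "bool"
    if isinstance(fill_value, int):
        return "int64"
    return None  # leave float/complex fills to paddle.full's default

print(infer_fill_dtype(True), infer_fill_dtype(3), infer_fill_dtype(2.5))
# -> bool int64 None
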
@@ -914,6 +975,8 @@ def triu(x: array, /, *, k: int = 0) -> array:


 def expand_dims(x: array, /, *, axis: int = 0) -> array:
+    if axis < -x.ndim - 1 or axis > x.ndim:
+        raise IndexError(f"Axis {axis} is out of bounds for array of dimension {x.ndim}")
     return paddle.unsqueeze(x, axis)

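For an ndim-dimensional input there are ndim + 1 valid insertion positions, so the accepted axis range is -ndim - 1 through ndim inclusive, which is exactly what the new bounds check encodes. Sketch:

import paddle

x = paddle.ones([2, 3])               # ndim == 2
print(paddle.unsqueeze(x, 2).shape)   # [2, 3, 1]; axis == ndim is allowed
axis = 3
if axis < -x.ndim - 1 or axis > x.ndim:
    print(f"Axis {axis} is out of bounds for array of dimension {x.ndim}")
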
@@ -973,6 +1036,22 @@ def unique_values(x: array) -> array:

 def matmul(x1: array, x2: array, /, **kwargs) -> array:
     # paddle.matmul doesn't type promote (but differently from _fix_promotion)
+    d1 = x1.ndim
+    d2 = x2.ndim
+
+    if d1 == 0 or d2 == 0:
+        raise ValueError("matmul does not support 0-D (scalar) inputs.")
+
+    k1 = x1.shape[-1]
+
+    if d2 == 1:
+        k2 = x2.shape[0]
+    else:
+        k2 = x2.shape[-2]
+
+    if k1 != k2:
+        raise ValueError(f"Shapes {x1.shape} and {x2.shape} are not aligned for matmul: "
+                         f"{k1} (dim -1 of x1) != {k2} (dim -2 of x2)")
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     return paddle.matmul(x1, x2, **kwargs)

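The early shape check turns an opaque backend error into a clear ValueError before paddle.matmul runs. A sketch of the aligned and misaligned cases:

import paddle

a = paddle.ones([2, 3])
b = paddle.ones([3, 4])
print(paddle.matmul(a, b).shape)  # [2, 4]; inner dimensions agree (3 == 3)

c = paddle.ones([5, 4])
k1, k2 = a.shape[-1], c.shape[-2]
if k1 != k2:
    print(f"Shapes {a.shape} and {c.shape} are not aligned: {k1} != {k2}")
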
@@ -988,7 +1067,36 @@ def meshgrid(*arrays: array, indexing: str = "xy") -> List[array]:


 def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
-    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    shape1 = x1.shape
+    shape2 = x2.shape
+    rank1 = len(shape1)
+    rank2 = len(shape2)
+    if rank1 == 0 or rank2 == 0:
+        raise ValueError(
+            f"Vector dot product requires non-scalar inputs (rank > 0). "
+            f"Got ranks {rank1} and {rank2} for shapes {shape1} and {shape2}."
+        )
+    try:
+        norm_axis1 = axis if axis >= 0 else rank1 + axis
+        if not (0 <= norm_axis1 < rank1):
+            raise IndexError  # Axis out of bounds for x1
+        norm_axis2 = axis if axis >= 0 else rank2 + axis
+        if not (0 <= norm_axis2 < rank2):
+            raise IndexError  # Axis out of bounds for x2
+        size1 = shape1[norm_axis1]
+        size2 = shape2[norm_axis2]
+    except IndexError:
+        raise ValueError(
+            f"Axis {axis} is out of bounds for input shapes {shape1} (rank {rank1}) "
+            f"and/or {shape2} (rank {rank2})."
+        )
+
+    if size1 != size2:
+        raise ValueError(
+            f"Inputs must have the same dimension size along the reduction axis ({axis}). "
+            f"Got shapes {shape1} and {shape2}, with sizes {size1} and {size2} "
+            f"along the normalized axis {norm_axis1} and {norm_axis2} respectively."
+        )
     return paddle.linalg.vecdot(x1, x2, axis=axis)

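The axis handling above boils down to the usual normalisation: shift a negative axis by the rank, then bounds-check. A standalone sketch (normalize_axis is a hypothetical helper, not part of the diff):

def normalize_axis(axis, rank):
    norm = axis if axis >= 0 else rank + axis
    if not (0 <= norm < rank):
        raise ValueError(f"Axis {axis} is out of bounds for rank {rank}")
    return norm

print(normalize_axis(-1, 3))  # 2
print(normalize_axis(1, 2))   # 1
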
@@ -1063,21 +1171,39 @@ def is_complex(dtype):


 def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
-    if axis is None:
+    _axis = axis
+    if _axis is None:
         if x.ndim != 1:
-            raise ValueError("axis must be specified when ndim > 1")
-        axis = 0
-    return paddle.index_select(x, axis, indices, **kwargs)
+            raise ValueError("axis must be specified when x.ndim > 1")
+        _axis = 0
+    elif not isinstance(_axis, int):
+        raise TypeError(f"axis must be an integer, but received {type(_axis)}")
+
+    if not (-x.ndim <= _axis < x.ndim):
+        raise IndexError(f"axis {_axis} is out of bounds for tensor of dimension {x.ndim}")
+
+    if isinstance(indices, paddle.Tensor):
+        indices_tensor = indices
+    elif isinstance(indices, int):
+        indices_tensor = paddle.to_tensor([indices], dtype='int64')
+    else:
+        # Otherwise (e.g., list, tuple), convert directly
+        indices_tensor = paddle.to_tensor(indices, dtype='int64')
+    # Ensure indices is a 1D tensor
+    if indices_tensor.ndim == 0:
+        indices_tensor = indices_tensor.reshape([1])
+    elif indices_tensor.ndim > 1:
+        raise ValueError(f"indices must be a 1D tensor, but received a {indices_tensor.ndim}D tensor")
+
+    return paddle.index_select(x, index=indices_tensor, axis=_axis)


 def sign(x: array, /) -> array:
     # paddle sign() does not support complex numbers and does not propagate
     # nans. See https://github.com/data-apis/array-api-compat/issues/136
-    if paddle.is_complex(x):
-        out = x / paddle.abs(x)
-        # sign(0) = 0 but the above formula would give nan
-        out[x == 0 + 0j] = 0 + 0j
-        return out
+    if paddle.is_complex(x) and x.ndim == 0 and x.item() == 0j:
+        # Handle 0-D complex zero explicitly
+        return paddle.zeros_like(x, dtype=x.dtype)
     else:
         out = paddle.sign(x)
         if paddle.is_floating_point(x):
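
Everything the new take() does ultimately funnels into a single paddle.index_select call with a 1-D int64 index tensor. The call it ends up making looks like this:

import paddle

x = paddle.to_tensor([10.0, 20.0, 30.0])
idx = paddle.to_tensor([2, 0], dtype="int64")
print(paddle.index_select(x, index=idx, axis=0))  # [30., 10.]
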

array_api_compat/paddle/linalg.py (40 additions, 7 deletions)
@@ -84,6 +84,8 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
     # Use our wrapped sum to make sure it does upcasting correctly
     return sum(paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1), axis=-1, dtype=dtype)

+def diagonal(x: ndarray, /, *, offset: int = 0, **kwargs) -> ndarray:
+    return paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)

 def vector_norm(
     x: array,
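
The new diagonal() wrapper pins axis1/axis2 to the trailing two dimensions, so batched inputs get one diagonal per matrix, matching the array API convention. Sketch:

import paddle

x = paddle.arange(8, dtype="float32").reshape([2, 2, 2])
print(paddle.diagonal(x, offset=0, axis1=-2, axis2=-1).shape)  # [2, 2]
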
@@ -123,24 +125,52 @@ def matrix_norm(
     keepdims: bool = False,
     ord: Optional[Union[int, float, Literal["fro", "nuc"]]] = "fro",
 ) -> array:
-    return paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims)
-
+    res = paddle.linalg.matrix_norm(x, p=ord, axis=(-2, -1), keepdim=keepdims)
+    if res.dtype == paddle.complex64:
+        res = paddle.cast(res, "float32")
+    if res.dtype == paddle.complex128:
+        res = paddle.cast(res, "float64")
+    return res

 def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
     if rtol is None:
         return paddle.linalg.pinv(x)
+
+    # change rtol shape
+    if isinstance(rtol, (int, float)):
+        rtol = paddle.to_tensor(rtol, dtype=x.dtype)
+
+    # broadcast rtol to [..., 1]
+    if rtol.ndim > 0:
+        rtol = rtol.unsqueeze(-1)

     return paddle.linalg.pinv(x, rcond=rtol)


 def slogdet(x: array):
-    det = paddle.linalg.det(x)
-    sign = paddle.sign(det)
-    log_det = paddle.log(det)
+    return tuple_to_namedtuple(paddle.linalg.slogdet(x), ["sign", "logabsdet"])
+
+def tuple_to_namedtuple(data, fields):
+    nt_class = namedtuple('DynamicNameTuple', fields)
+    return nt_class(*data)
+
+def eigh(x: array):
+    return tuple_to_namedtuple(paddle.linalg.eigh(x), ['eigenvalues', 'eigenvectors'])
+
+def qr(x: array, mode: Optional[str] = None) -> array:
+    if mode is None:
+        return tuple_to_namedtuple(paddle.linalg.qr(x), ['Q', 'R'])
+
+    return tuple_to_namedtuple(paddle.linalg.qr(x, mode), ['Q', 'R'])
+

-    slotdet = namedtuple("slotdet", ["sign", "logabsdet"])
-    return slotdet(sign, log_det)
+def svd(x: array, full_matrices: Optional[bool] = None) -> array:
+    if full_matrices is None:
+        return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices=True), ['U', 'S', 'Vh'])
+    return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices), ['U', 'S', 'Vh'])

+def svdvals(x: array) -> array:
+    return paddle.linalg.svd(x)[1]

 __all__ = linalg_all + [
     "outer",
@@ -154,6 +184,9 @@ def slogdet(x: array):
     "trace",
     "vector_norm",
     "slogdet",
+    "eigh",
+    "diagonal",
+    "svdvals"
 ]

 _all_ignore = ["paddle_linalg", "sum"]

paddle-xfails.txt (59 additions, 0 deletions)
@@ -106,3 +106,62 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_ceil
 array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
 array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
 array_api_tests/test_searching_functions.py::test_where
+
+array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_pow[pow(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_hypot
+array_api_tests/test_operators_and_elementwise_functions.py::test_copysign
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)]
+array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)]
+array_api_tests/test_linalg.py::test_outer
+array_api_tests/test_linalg.py::test_vecdot
+array_api_tests/test_operators_and_elementwise_functions.py::test_clip
+array_api_tests/test_manipulation_functions.py::test_stack
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide
+
+# do not pass
+array_api_tests/test_has_names[array_attribute-device]
+array_api_tests/test_signatures.py::test_func_signature[meshgrid]
+
+array_api_tests/test_operators_and_elementwise_functions.py::test_equal[equal(x1, x2)]
+array_api_tests/test_indexing_functions.py::test_take
+array_api_tests/test_linalg.py::test_linalg_vecdot
+array_api_tests/test_creation_functions.py::test_asarray_arrays
+
+array_api_tests/test_linalg.py::test_qr
+
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or
+array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor
+
+# test exceeds the deadline of 800ms
+array_api_tests/test_linalg.py::test_pinv
+array_api_tests/test_linalg.py::test_det
+
+# only supports access to dimension 0 to 9, but received dimension is 10.
+array_api_tests/test_linalg.py::test_tensordot
+array_api_tests/test_linalg.py::test_linalg_tensordot
