Skip to content

Commit 37785d4

Browse files
committed
Add broadcast_tensors alias, modify result_type
1 parent fd6eea0 commit 37785d4

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

array_api_compat/paddle/_aliases.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import builtins
34
from typing import Literal
45
import numpy as np
56

@@ -112,25 +113,32 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
112113
raise TypeError("At least one array or dtype must be provided")
113114
if len(arrays_and_dtypes) == 1:
114115
x = arrays_and_dtypes[0]
115-
if isinstance(x, paddle.dtype):
116-
return x
117-
return x.dtype
116+
return x if isinstance(x, paddle.dtype) else x.dtype
118117
if len(arrays_and_dtypes) > 2:
119118
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
120119

121120
x, y = arrays_and_dtypes
122-
xdt = x.dtype if not isinstance(x, paddle.dtype) else x
123-
ydt = y.dtype if not isinstance(y, paddle.dtype) else y
121+
xdt = x if isinstance(x, paddle.dtype) else x.dtype
122+
ydt = y if isinstance(y, paddle.dtype) else y.dtype
124123

125124
if (xdt, ydt) in _promotion_table:
126-
return _promotion_table[xdt, ydt]
127-
128-
# This doesn't result_type(dtype, dtype) for non-array API dtypes
129-
# because paddle.result_type only accepts tensors. This does however, allow
130-
# cross-kind promotion.
131-
x = paddle.to_tensor([], dtype=x) if isinstance(x, paddle.dtype) else x
132-
y = paddle.to_tensor([], dtype=y) if isinstance(y, paddle.dtype) else y
133-
return paddle.result_type(x, y)
125+
return _promotion_table[(xdt, ydt)]
126+
127+
type_order = {
128+
paddle.bool: 0,
129+
paddle.int8: 1,
130+
paddle.uint8: 2,
131+
paddle.int16: 3,
132+
paddle.int32: 4,
133+
paddle.int64: 5,
134+
paddle.float16: 6,
135+
paddle.float32: 7,
136+
paddle.float64: 8,
137+
paddle.complex64: 9,
138+
paddle.complex128: 10
139+
}
140+
141+
return xdt if type_order.get(xdt, 0) > type_order.get(ydt, 0) else ydt
134142

135143

136144
def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
@@ -922,7 +930,15 @@ def astype(
922930

923931

924932
def broadcast_arrays(*arrays: array) -> List[array]:
925-
return paddle.broadcast_tensors(arrays)
933+
original_dtypes = [arr.dtype for arr in arrays]
934+
if len(set(original_dtypes)) == 1:
935+
return paddle.broadcast_tensors(arrays)
936+
target_dtype = result_type(*arrays)
937+
casted_arrays = [arr.astype(target_dtype) if arr.dtype != target_dtype else arr
938+
for arr in arrays]
939+
broadcasted = paddle.broadcast_tensors(casted_arrays)
940+
result = [arr.astype(original_dtype) for arr, original_dtype in zip(broadcasted, original_dtypes)]
941+
return result
926942

927943

928944
# Note that these named tuples aren't actually part of the standard namespace,

0 commit comments

Comments
 (0)