Skip to content

Commit 2d4e571

Browse files
Merge pull request #1 from cangtianhuang/support_paddle
Add broadcast_tensors alias, modify result_type
2 parents fd6eea0 + 0651731 commit 2d4e571

File tree

1 file changed

+29
-14
lines changed

1 file changed

+29
-14
lines changed

array_api_compat/paddle/_aliases.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -112,25 +112,32 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
112112
raise TypeError("At least one array or dtype must be provided")
113113
if len(arrays_and_dtypes) == 1:
114114
x = arrays_and_dtypes[0]
115-
if isinstance(x, paddle.dtype):
116-
return x
117-
return x.dtype
115+
return x if isinstance(x, paddle.dtype) else x.dtype
118116
if len(arrays_and_dtypes) > 2:
119117
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
120118

121119
x, y = arrays_and_dtypes
122-
xdt = x.dtype if not isinstance(x, paddle.dtype) else x
123-
ydt = y.dtype if not isinstance(y, paddle.dtype) else y
120+
xdt = x if isinstance(x, paddle.dtype) else x.dtype
121+
ydt = y if isinstance(y, paddle.dtype) else y.dtype
124122

125123
if (xdt, ydt) in _promotion_table:
126-
return _promotion_table[xdt, ydt]
127-
128-
# This doesn't result_type(dtype, dtype) for non-array API dtypes
129-
# because paddle.result_type only accepts tensors. This does however, allow
130-
# cross-kind promotion.
131-
x = paddle.to_tensor([], dtype=x) if isinstance(x, paddle.dtype) else x
132-
y = paddle.to_tensor([], dtype=y) if isinstance(y, paddle.dtype) else y
133-
return paddle.result_type(x, y)
124+
return _promotion_table[(xdt, ydt)]
125+
126+
type_order = {
127+
paddle.bool: 0,
128+
paddle.int8: 1,
129+
paddle.uint8: 2,
130+
paddle.int16: 3,
131+
paddle.int32: 4,
132+
paddle.int64: 5,
133+
paddle.float16: 6,
134+
paddle.float32: 7,
135+
paddle.float64: 8,
136+
paddle.complex64: 9,
137+
paddle.complex128: 10
138+
}
139+
140+
return xdt if type_order.get(xdt, 0) > type_order.get(ydt, 0) else ydt
134141

135142

136143
def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
@@ -922,7 +929,15 @@ def astype(
922929

923930

924931
def broadcast_arrays(*arrays: array) -> List[array]:
925-
return paddle.broadcast_tensors(arrays)
932+
original_dtypes = [arr.dtype for arr in arrays]
933+
if len(set(original_dtypes)) == 1:
934+
return paddle.broadcast_tensors(arrays)
935+
target_dtype = result_type(*arrays)
936+
casted_arrays = [arr.astype(target_dtype) if arr.dtype != target_dtype else arr
937+
for arr in arrays]
938+
broadcasted = paddle.broadcast_tensors(casted_arrays)
939+
result = [arr.astype(original_dtype) for arr, original_dtype in zip(broadcasted, original_dtypes)]
940+
return result
926941

927942

928943
# Note that these named tuples aren't actually part of the standard namespace,

0 commit comments

Comments
 (0)