|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import builtins |
3 | 4 | from typing import Literal
|
4 | 5 | import numpy as np
|
5 | 6 |
|
@@ -112,25 +113,32 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
|
112 | 113 | raise TypeError("At least one array or dtype must be provided")
|
113 | 114 | if len(arrays_and_dtypes) == 1:
|
114 | 115 | x = arrays_and_dtypes[0]
|
115 |
| - if isinstance(x, paddle.dtype): |
116 |
| - return x |
117 |
| - return x.dtype |
| 116 | + return x if isinstance(x, paddle.dtype) else x.dtype |
118 | 117 | if len(arrays_and_dtypes) > 2:
|
119 | 118 | return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
|
120 | 119 |
|
121 | 120 | x, y = arrays_and_dtypes
|
122 |
| - xdt = x.dtype if not isinstance(x, paddle.dtype) else x |
123 |
| - ydt = y.dtype if not isinstance(y, paddle.dtype) else y |
| 121 | + xdt = x if isinstance(x, paddle.dtype) else x.dtype |
| 122 | + ydt = y if isinstance(y, paddle.dtype) else y.dtype |
124 | 123 |
|
125 | 124 | if (xdt, ydt) in _promotion_table:
|
126 |
| - return _promotion_table[xdt, ydt] |
127 |
| - |
128 |
| - # This doesn't result_type(dtype, dtype) for non-array API dtypes |
129 |
| - # because paddle.result_type only accepts tensors. This does however, allow |
130 |
| - # cross-kind promotion. |
131 |
| - x = paddle.to_tensor([], dtype=x) if isinstance(x, paddle.dtype) else x |
132 |
| - y = paddle.to_tensor([], dtype=y) if isinstance(y, paddle.dtype) else y |
133 |
| - return paddle.result_type(x, y) |
| 125 | + return _promotion_table[(xdt, ydt)] |
| 126 | + |
| 127 | + type_order = { |
| 128 | + paddle.bool: 0, |
| 129 | + paddle.int8: 1, |
| 130 | + paddle.uint8: 2, |
| 131 | + paddle.int16: 3, |
| 132 | + paddle.int32: 4, |
| 133 | + paddle.int64: 5, |
| 134 | + paddle.float16: 6, |
| 135 | + paddle.float32: 7, |
| 136 | + paddle.float64: 8, |
| 137 | + paddle.complex64: 9, |
| 138 | + paddle.complex128: 10 |
| 139 | + } |
| 140 | + |
| 141 | + return xdt if type_order.get(xdt, 0) > type_order.get(ydt, 0) else ydt |
134 | 142 |
|
135 | 143 |
|
136 | 144 | def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
|
@@ -922,7 +930,15 @@ def astype(
|
922 | 930 |
|
923 | 931 |
|
924 | 932 | def broadcast_arrays(*arrays: array) -> List[array]:
|
925 |
| - return paddle.broadcast_tensors(arrays) |
| 933 | + original_dtypes = [arr.dtype for arr in arrays] |
| 934 | + if len(set(original_dtypes)) == 1: |
| 935 | + return paddle.broadcast_tensors(arrays) |
| 936 | + target_dtype = result_type(*arrays) |
| 937 | + casted_arrays = [arr.astype(target_dtype) if arr.dtype != target_dtype else arr |
| 938 | + for arr in arrays] |
| 939 | + broadcasted = paddle.broadcast_tensors(casted_arrays) |
| 940 | + result = [arr.astype(original_dtype) for arr, original_dtype in zip(broadcasted, original_dtypes)] |
| 941 | + return result |
926 | 942 |
|
927 | 943 |
|
928 | 944 | # Note that these named tuples aren't actually part of the standard namespace,
|
|
0 commit comments