@@ -112,25 +112,32 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
112
112
raise TypeError ("At least one array or dtype must be provided" )
113
113
if len (arrays_and_dtypes ) == 1 :
114
114
x = arrays_and_dtypes [0 ]
115
- if isinstance (x , paddle .dtype ):
116
- return x
117
- return x .dtype
115
+ return x if isinstance (x , paddle .dtype ) else x .dtype
118
116
if len (arrays_and_dtypes ) > 2 :
119
117
return result_type (arrays_and_dtypes [0 ], result_type (* arrays_and_dtypes [1 :]))
120
118
121
119
x , y = arrays_and_dtypes
122
- xdt = x . dtype if not isinstance (x , paddle .dtype ) else x
123
- ydt = y . dtype if not isinstance (y , paddle .dtype ) else y
120
+ xdt = x if isinstance (x , paddle .dtype ) else x . dtype
121
+ ydt = y if isinstance (y , paddle .dtype ) else y . dtype
124
122
125
123
if (xdt , ydt ) in _promotion_table :
126
- return _promotion_table [xdt , ydt ]
127
-
128
- # This doesn't result_type(dtype, dtype) for non-array API dtypes
129
- # because paddle.result_type only accepts tensors. This does however, allow
130
- # cross-kind promotion.
131
- x = paddle .to_tensor ([], dtype = x ) if isinstance (x , paddle .dtype ) else x
132
- y = paddle .to_tensor ([], dtype = y ) if isinstance (y , paddle .dtype ) else y
133
- return paddle .result_type (x , y )
124
+ return _promotion_table [(xdt , ydt )]
125
+
126
+ type_order = {
127
+ paddle .bool : 0 ,
128
+ paddle .int8 : 1 ,
129
+ paddle .uint8 : 2 ,
130
+ paddle .int16 : 3 ,
131
+ paddle .int32 : 4 ,
132
+ paddle .int64 : 5 ,
133
+ paddle .float16 : 6 ,
134
+ paddle .float32 : 7 ,
135
+ paddle .float64 : 8 ,
136
+ paddle .complex64 : 9 ,
137
+ paddle .complex128 : 10
138
+ }
139
+
140
+ return xdt if type_order .get (xdt , 0 ) > type_order .get (ydt , 0 ) else ydt
134
141
135
142
136
143
def can_cast (from_ : Union [Dtype , array ], to : Dtype , / ) -> bool :
@@ -922,7 +929,15 @@ def astype(
922
929
923
930
924
931
def broadcast_arrays (* arrays : array ) -> List [array ]:
925
- return paddle .broadcast_tensors (arrays )
932
+ original_dtypes = [arr .dtype for arr in arrays ]
933
+ if len (set (original_dtypes )) == 1 :
934
+ return paddle .broadcast_tensors (arrays )
935
+ target_dtype = result_type (* arrays )
936
+ casted_arrays = [arr .astype (target_dtype ) if arr .dtype != target_dtype else arr
937
+ for arr in arrays ]
938
+ broadcasted = paddle .broadcast_tensors (casted_arrays )
939
+ result = [arr .astype (original_dtype ) for arr , original_dtype in zip (broadcasted , original_dtypes )]
940
+ return result
926
941
927
942
928
943
# Note that these named tuples aren't actually part of the standard namespace,
0 commit comments