@@ -194,7 +194,7 @@ def isdtype(
194
194
else :
195
195
raise TypeError (f"'kind' must be a dtype, str, or tuple of dtypes and strs, not { type (kind ).__name__ } " )
196
196
197
- def result_type (* arrays_and_dtypes : Union [Array , Dtype ]) -> Dtype :
197
+ def result_type (* arrays_and_dtypes : Union [Array , Dtype , int , float , complex , bool ]) -> Dtype :
198
198
"""
199
199
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
200
200
@@ -205,19 +205,40 @@ def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
205
205
# too many extra type promotions like int64 + uint64 -> float64, and does
206
206
# value-based casting on scalar arrays.
207
207
A = []
208
+ scalars = []
208
209
for a in arrays_and_dtypes :
209
210
if isinstance (a , Array ):
210
211
a = a .dtype
212
+ elif isinstance (a , (bool , int , float , complex )):
213
+ scalars .append (a )
211
214
elif isinstance (a , np .ndarray ) or a not in _all_dtypes :
212
215
raise TypeError ("result_type() inputs must be array_api arrays or dtypes" )
213
216
A .append (a )
214
217
218
+ # remove python scalars
219
+ A = [a for a in A if not isinstance (a , (bool , int , float , complex ))]
220
+
215
221
if len (A ) == 0 :
216
222
raise ValueError ("at least one array or dtype is required" )
217
223
elif len (A ) == 1 :
218
- return A [0 ]
224
+ result = A [0 ]
219
225
else :
220
226
t = A [0 ]
221
227
for t2 in A [1 :]:
222
228
t = _result_type (t , t2 )
223
- return t
229
+ result = t
230
+
231
+ if len (scalars ) == 0 :
232
+ return result
233
+
234
+ if get_array_api_strict_flags ()['api_version' ] <= '2023.12' :
235
+ raise TypeError ("result_type() inputs must be array_api arrays or dtypes" )
236
+
237
+ # promote python scalars given the result_type for all arrays/dtypes
238
+ from ._creation_functions import empty
239
+ arr = empty (1 , dtype = result )
240
+ for s in scalars :
241
+ x = arr ._promote_scalar (s )
242
+ result = _result_type (x .dtype , result )
243
+
244
+ return result
0 commit comments