Skip to content

Commit 4e68c66

Browse files
committed
ENH: allow python scalars as inputs to result_type
1 parent d6b87c1 commit 4e68c66

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

array_api_strict/_data_type_functions.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def isdtype(
194194
else:
195195
raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}")
196196

197-
def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
197+
def result_type(*arrays_and_dtypes: Union[Array, Dtype, int, float, complex, bool]) -> Dtype:
198198
"""
199199
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
200200
@@ -205,19 +205,40 @@ def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
205205
# too many extra type promotions like int64 + uint64 -> float64, and does
206206
# value-based casting on scalar arrays.
207207
A = []
208+
scalars = []
208209
for a in arrays_and_dtypes:
209210
if isinstance(a, Array):
210211
a = a.dtype
212+
elif isinstance(a, (bool, int, float, complex)):
213+
scalars.append(a)
211214
elif isinstance(a, np.ndarray) or a not in _all_dtypes:
212215
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
213216
A.append(a)
214217

218+
# remove python scalars
219+
A = [a for a in A if not isinstance(a, (bool, int, float, complex))]
220+
215221
if len(A) == 0:
216222
raise ValueError("at least one array or dtype is required")
217223
elif len(A) == 1:
218-
return A[0]
224+
result = A[0]
219225
else:
220226
t = A[0]
221227
for t2 in A[1:]:
222228
t = _result_type(t, t2)
223-
return t
229+
result = t
230+
231+
if len(scalars) == 0:
232+
return result
233+
234+
if get_array_api_strict_flags()['api_version'] <= '2023.12':
235+
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
236+
237+
# promote python scalars given the result_type for all arrays/dtypes
238+
from ._creation_functions import empty
239+
arr = empty(1, dtype=result)
240+
for s in scalars:
241+
x = arr._promote_scalar(s)
242+
result = _result_type(x.dtype, result)
243+
244+
return result

0 commit comments

Comments
 (0)