Skip to content

Commit 564d885

Browse files
committed
MAINT: result_type cosmetic refactor
1 parent ea5deb1 commit 564d885

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

array_api_strict/_data_type_functions.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -204,31 +204,28 @@ def result_type(
204204
# required by the spec rather than using np.result_type. NumPy implements
205205
# too many extra type promotions like int64 + uint64 -> float64, and does
206206
# value-based casting on scalar arrays.
207-
A = []
207+
dtypes = []
208208
scalars = []
209209
for a in arrays_and_dtypes:
210-
if isinstance(a, Array):
211-
a = a.dtype
210+
if isinstance(a, DType):
211+
dtypes.append(a)
212+
elif isinstance(a, Array):
213+
dtypes.append(a.dtype)
212214
elif isinstance(a, (bool, int, float, complex)):
213215
scalars.append(a)
214-
elif isinstance(a, np.ndarray) or a not in _all_dtypes:
216+
else:
215217
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
216-
A.append(a)
217-
218-
# remove python scalars
219-
B = [a for a in A if not isinstance(a, (bool, int, float, complex))]
220218

221-
if len(B) == 0:
219+
if not dtypes:
222220
raise ValueError("at least one array or dtype is required")
223-
elif len(B) == 1:
224-
result = B[0]
221+
elif len(dtypes) == 1:
222+
result = dtypes[0]
225223
else:
226-
t = B[0]
227-
for t2 in B[1:]:
228-
t = _result_type(t, t2)
229-
result = t
224+
result = dtypes[0]
225+
for t2 in dtypes[1:]:
226+
result = _result_type(result, t2)
230227

231-
if len(scalars) == 0:
228+
if not scalars:
232229
return result
233230

234231
if get_array_api_strict_flags()['api_version'] <= '2023.12':

0 commit comments

Comments
 (0)