Skip to content

Commit 39d285c

Browse files
committed
tweak _result_type
1 parent 845d8b5 commit 39d285c

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,13 +151,18 @@ def result_type(
151151
return _reduce(_result_type, others + scalars)
152152

153153

154-
def _result_type(x, y):
154+
def _result_type(
155+
x: Array | DType | bool | int | float | complex,
156+
y: Array | DType | bool | int | float | complex,
157+
) -> DType:
155158
if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)):
156-
xdt = x.dtype if not isinstance(x, torch.dtype) else x
157-
ydt = y.dtype if not isinstance(y, torch.dtype) else y
159+
xdt = x if isinstance(x, torch.dtype) else x.dtype
160+
ydt = y if isinstance(y, torch.dtype) else y.dtype
158161

159-
if (xdt, ydt) in _promotion_table:
162+
try:
160163
return _promotion_table[xdt, ydt]
164+
except KeyError:
165+
pass
161166

162167
# This doesn't result_type(dtype, dtype) for non-array API dtypes
163168
# because torch.result_type only accepts tensors. This does however, allow

0 commit comments

Comments
 (0)