Skip to content

Commit 1ae5c1b

Browse files
committed
Revert uint promotions
1 parent 39d285c commit 1ae5c1b

File tree

1 file changed

+6
-34
lines changed

1 file changed

+6
-34
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 6 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
try:
2323
# torch >=2.3
2424
_int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
25-
_HAS_LARGE_UINT = True
2625
except AttributeError:
27-
_HAS_LARGE_UINT = False
26+
pass
27+
2828

2929
_array_api_dtypes = {
3030
torch.bool,
@@ -59,28 +59,6 @@
5959
(torch.float64, torch.complex128): torch.complex128,
6060
}
6161

62-
if _HAS_LARGE_UINT: # torch >=2.3
63-
_promotion_table.update(
64-
{
65-
# uints
66-
(torch.uint8, torch.uint16): torch.uint16,
67-
(torch.uint8, torch.uint32): torch.uint32,
68-
(torch.uint8, torch.uint64): torch.uint64,
69-
(torch.uint16, torch.uint32): torch.uint32,
70-
(torch.uint16, torch.uint64): torch.uint64,
71-
(torch.uint32, torch.uint64): torch.uint64,
72-
# ints and uints (mixed sign)
73-
(torch.uint16, torch.int8): torch.int32,
74-
(torch.uint16, torch.int16): torch.int32,
75-
(torch.uint16, torch.int32): torch.int32,
76-
(torch.uint16, torch.int64): torch.int64,
77-
(torch.uint32, torch.int8): torch.int64,
78-
(torch.uint32, torch.int16): torch.int64,
79-
(torch.uint32, torch.int32): torch.int64,
80-
(torch.uint32, torch.int64): torch.int64,
81-
}
82-
)
83-
8462
_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
8563
_promotion_table.update({(a, a): a for a in _array_api_dtypes})
8664

@@ -317,16 +295,10 @@ def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
317295
if dtype is not None:
318296
return x.clone() if dtype == x.dtype else x.to(dtype)
319297

320-
if x.dtype in (torch.int8, torch.int16, torch.int32):
321-
return x.to(torch.int64)
322-
323-
if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
324-
return x.to(torch.uint64)
325-
326-
if x.dtype == torch.uint8:
327-
# We can't upcast uint8 according to the spec because there is no
328-
# torch.uint64, so at least upcast to int64 which is what prod does
329-
# when axis=None.
298+
# We can't upcast uint8 according to the spec because there is no
299+
# torch.uint64, so at least upcast to int64 which is what prod does
300+
# when axis=None.
301+
if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
330302
return x.to(torch.int64)
331303

332304
return x.clone()

0 commit comments

Comments
 (0)