@@ -22,9 +22,9 @@
 try:
     # torch >=2.3
     _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
-    _HAS_LARGE_UINT = True
 except AttributeError:
-    _HAS_LARGE_UINT = False
+    pass
+
 
 _array_api_dtypes = {
     torch.bool,
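The hunk above drops the _HAS_LARGE_UINT flag: the try/except already doubles as the feature probe, since per the code's own comment the large unsigned dtypes only exist on torch >= 2.3, and accessing them on older builds raises AttributeError. A minimal self-contained sketch of that probe pattern (the base dtype set is abbreviated here for illustration):

import torch

_int_dtypes = {torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8}
try:
    # torch >=2.3: the large unsigned dtypes exist and join the set
    _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
except AttributeError:
    # older torch: the attribute lookup fails before the set is extended
    pass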
@@ -59,28 +59,6 @@
     (torch.float64, torch.complex128): torch.complex128,
 }
 
-if _HAS_LARGE_UINT:  # torch >=2.3
-    _promotion_table.update(
-        {
-            # uints
-            (torch.uint8, torch.uint16): torch.uint16,
-            (torch.uint8, torch.uint32): torch.uint32,
-            (torch.uint8, torch.uint64): torch.uint64,
-            (torch.uint16, torch.uint32): torch.uint32,
-            (torch.uint16, torch.uint64): torch.uint64,
-            (torch.uint32, torch.uint64): torch.uint64,
-            # ints and uints (mixed sign)
-            (torch.uint16, torch.int8): torch.int32,
-            (torch.uint16, torch.int16): torch.int32,
-            (torch.uint16, torch.int32): torch.int32,
-            (torch.uint16, torch.int64): torch.int64,
-            (torch.uint32, torch.int8): torch.int64,
-            (torch.uint32, torch.int16): torch.int64,
-            (torch.uint32, torch.int32): torch.int64,
-            (torch.uint32, torch.int64): torch.int64,
-        }
-    )
-
 _promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
 _promotion_table.update({(a, a): a for a in _array_api_dtypes})
 
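With the conditional block gone, only the two closing update calls remain. They complete the promotion table: each mixed-dtype promotion is written once, then mirrored so lookup order does not matter, and every dtype is mapped to itself. A short sketch of the idiom with an abbreviated table (the table contents here are illustrative, not the module's full table):

import torch

# abbreviated stand-in for the module's _promotion_table
_promotion_table = {
    (torch.int8, torch.int16): torch.int16,
    (torch.float32, torch.float64): torch.float64,
}
# mirror every (a, b) -> c entry as (b, a) -> c
_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
# add the trivial same-dtype entries
dtypes = (torch.int8, torch.int16, torch.float32, torch.float64)
_promotion_table.update({(a, a): a for a in dtypes})

assert _promotion_table[(torch.int16, torch.int8)] == torch.int16
assert _promotion_table[(torch.float32, torch.float32)] == torch.float32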
@@ -317,16 +295,10 @@ def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
     if dtype is not None:
         return x.clone() if dtype == x.dtype else x.to(dtype)
 
-    if x.dtype in (torch.int8, torch.int16, torch.int32):
-        return x.to(torch.int64)
-
-    if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
-        return x.to(torch.uint64)
-
-    if x.dtype == torch.uint8:
-        # We can't upcast uint8 according to the spec because there is no
-        # torch.uint64, so at least upcast to int64 which is what prod does
-        # when axis=None.
+    # We can't upcast uint8 according to the spec because there is no
+    # torch.uint64, so at least upcast to int64 which is what prod does
+    # when axis=None.
+    if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
         return x.to(torch.int64)
 
     return x.clone()
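To make the collapsed branch concrete, here is a standalone rendering of the patched helper (untyped, so it runs without the module's Array and DType aliases) together with a usage example: accumulating uint8 through int64 avoids the 8-bit wraparound that summing in uint8 would produce.

import torch

def _sum_prod_no_axis(x, dtype):
    # standalone sketch of the patched helper, for illustration only
    if dtype is not None:
        return x.clone() if dtype == x.dtype else x.to(dtype)
    # default: upcast small integer dtypes to int64, matching what prod
    # does when axis=None; after this revision uint8 is no longer routed
    # to uint64, so it takes the int64 fallback too
    if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
        return x.to(torch.int64)
    return x.clone()

x = torch.tensor([200, 200], dtype=torch.uint8)
print(_sum_prod_no_axis(x, None).sum())  # tensor(400); summing with dtype=torch.uint8 would wrap to 144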