Skip to content

Commit 602fdbd

Browse files
authored
Add uint64 to thunder->torch dtype map (#2519)
1 parent 5707666 commit 602fdbd

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

thunder/core/dtypes.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def __init__(self, name, shortname, *, bytes, is_weak):
239239

240240
boolean_dtypes = {bool8, bool8_, bool}
241241

242-
integer_dtypes = {d for d in all_dtypes if isinstance(d, exact)} | {bool, int}
242+
integer_dtypes = {d for d in all_dtypes if isinstance(d, exact)} | {bool, int} | {uint64, uint64_}
243243

244244
nonboolean_integer_dtypes = {d for d in integer_dtypes if (not isinstance(d, bool_) and d is not bool)}
245245

@@ -264,7 +264,7 @@ def __init__(self, name, shortname, *, bytes, is_weak):
264264

265265
weak_dtypes = {d for d in all_dtypes if d.is_weak} | all_numbertypes
266266

267-
strong_dtypes = {d for d in all_dtypes if not d.is_weak}
267+
strong_dtypes = {d for d in all_dtypes if not d.is_weak} | {uint64}
268268

269269

270270
def is_weak_dtype(dtype):
@@ -384,9 +384,11 @@ def corresponding_complex_dtype(dtype):
384384

385385

386386
_name_to_dtype_map = {dtype.full_name: dtype for dtype in all_dtypes}
387+
_name_to_dtype_map.update({uint64.full_name: uint64, uint64_.full_name: uint64_})
387388
_strong_dtype_to_weak_dtype_map = {
388389
dtype: _name_to_dtype_map[f"{dtype.full_name}_"] for dtype in all_dtypes if not dtype.is_weak
389390
}
391+
_strong_dtype_to_weak_dtype_map.update({uint64: uint64_})
390392

391393
_weak_dtype_to_strong_dtype_map = {v: k for k, v in _strong_dtype_to_weak_dtype_map.items()}
392394
_weak_dtype_to_strong_dtype_map.update(
@@ -395,6 +397,7 @@ def corresponding_complex_dtype(dtype):
395397
int: int64,
396398
float: float32,
397399
complex: complex64,
400+
uint64_: uint64,
398401
}
399402
)
400403

@@ -524,6 +527,12 @@ def are_same_dtypes(a, b, *, weak_and_strong_are_equivalent=True):
524527
if hasattr(torch, dtype.full_name.rstrip("_"))
525528
}
526529
)
530+
_thunder_to_torch_dtype_map.update(
531+
{
532+
uint64: torch.uint64,
533+
uint64_: torch.uint64,
534+
}
535+
)
527536

528537
_torch_to_thunder_dtype_map = {
529538
v: k

0 commit comments

Comments
 (0)