Skip to content

Commit 424a99b

Browse files
Make new datatypes optional (for compatibility with older torch versions) (#645)
1 parent 1ddecbb commit 424a99b

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

core/shark_turbine/aot/params.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,33 @@ def externalize_module_parameters(
6060
torch.complex128: "complex128",
6161
torch.float16: "float16",
6262
torch.bfloat16: "bfloat16",
63-
torch.float8_e4m3fn: "float8_e4m3fn",
64-
torch.float8_e4m3fnuz: "float8_e4m3fnuz",
65-
torch.float8_e5m2: "float8_e5m2",
66-
torch.float8_e5m2fnuz: "float8_e5m2fnuz",
6763
torch.int8: "int8",
6864
torch.int16: "int16",
6965
torch.int32: "int32",
7066
torch.int64: "int64",
71-
torch.uint16: "uint16",
72-
torch.uint32: "uint32",
73-
torch.uint64: "uint64",
7467
torch.uint8: "uint8",
7568
torch.bool: "bool",
7669
}
7770

71+
72+
# Deal with datatypes not yet added in all versions of Torch.
73+
def _add_optional_dtype(name: str):
74+
try:
75+
dtype = getattr(torch, name)
76+
except AttributeError:
77+
return
78+
_dtype_to_name[dtype] = name
79+
80+
81+
_add_optional_dtype("float8_e4m3fn")
82+
_add_optional_dtype("float8_e4m3fnuz")
83+
_add_optional_dtype("float8_e5m2")
84+
_add_optional_dtype("float8_e5m2fnuz")
85+
_add_optional_dtype("uint16")
86+
_add_optional_dtype("uint32")
87+
_add_optional_dtype("uint64")
88+
89+
7890
_name_to_dtype: dict[str, torch.dtype] = {v: k for k, v in _dtype_to_name.items()}
7991

8092
_metadata_prefix = "PYTORCH:"

0 commit comments

Comments
 (0)