@@ -60,21 +60,33 @@ def externalize_module_parameters(
     torch.complex128: "complex128",
     torch.float16: "float16",
     torch.bfloat16: "bfloat16",
-    torch.float8_e4m3fn: "float8_e4m3fn",
-    torch.float8_e4m3fnuz: "float8_e4m3fnuz",
-    torch.float8_e5m2: "float8_e5m2",
-    torch.float8_e5m2fnuz: "float8_e5m2fnuz",
     torch.int8: "int8",
     torch.int16: "int16",
     torch.int32: "int32",
     torch.int64: "int64",
-    torch.uint16: "uint16",
-    torch.uint32: "uint32",
-    torch.uint64: "uint64",
     torch.uint8: "uint8",
     torch.bool: "bool",
 }
 
+
+# Deal with datatypes not yet added in all versions of Torch.
+def _add_optional_dtype(name: str):
+    try:
+        dtype = getattr(torch, name)
+    except AttributeError:
+        return
+    _dtype_to_name[dtype] = name
+
+
+_add_optional_dtype("float8_e4m3fn")
+_add_optional_dtype("float8_e4m3fnuz")
+_add_optional_dtype("float8_e5m2")
+_add_optional_dtype("float8_e5m2fnuz")
+_add_optional_dtype("uint16")
+_add_optional_dtype("uint32")
+_add_optional_dtype("uint64")
+
+
 _name_to_dtype: dict[str, torch.dtype] = {v: k for k, v in _dtype_to_name.items()}
 
 _metadata_prefix = "PYTORCH:"
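
The change above replaces hard-coded table entries with conditional registration, so the module still imports on Torch builds that predate the float8 and wide-unsigned dtypes. The following is a minimal, self-contained sketch of that pattern; the trimmed _dtype_to_name table and the round-trip check at the end are illustrative, not part of the PR.

import torch

# Trimmed-down registry; the real module maps many more dtypes.
_dtype_to_name: dict[torch.dtype, str] = {
    torch.float32: "float32",
    torch.int8: "int8",
}


def _add_optional_dtype(name: str):
    # Register a dtype only if this Torch build actually defines it;
    # older releases simply skip the entry instead of failing at import.
    try:
        dtype = getattr(torch, name)
    except AttributeError:
        return
    _dtype_to_name[dtype] = name


# float8 dtypes exist only in newer Torch releases.
_add_optional_dtype("float8_e4m3fn")

# Reverse map is derived after all optional entries have been registered.
_name_to_dtype: dict[str, torch.dtype] = {v: k for k, v in _dtype_to_name.items()}

# Round-trip a dtype through its string name (illustrative usage only).
assert _name_to_dtype[_dtype_to_name[torch.int8]] is torch.int8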