diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index fb113ee835..ddb0e80309 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -199,7 +199,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType) ) if dtype.itemsize == 1 and array.dtype not in ( np.uint8, - ml_dtypes.float8_e4m3b11fnuz, + ml_dtypes.float8_e4m3fnuz, ml_dtypes.float8_e4m3fn, ml_dtypes.float8_e5m2fnuz, ml_dtypes.float8_e5m2,