diff --git a/tools/onnx-graphsurgeon/onnx_graphsurgeon/exporters/onnx_exporter.py b/tools/onnx-graphsurgeon/onnx_graphsurgeon/exporters/onnx_exporter.py index a39ebe22..2dd2538d 100644 --- a/tools/onnx-graphsurgeon/onnx_graphsurgeon/exporters/onnx_exporter.py +++ b/tools/onnx-graphsurgeon/onnx_graphsurgeon/exporters/onnx_exporter.py @@ -18,6 +18,7 @@ from typing import List, Sequence, Union +import ml_dtypes import numpy as np import onnx import onnx.numpy_helper @@ -131,12 +132,12 @@ def __call__(self, arr): _NUMPY_ARRAY_CONVERTERS = { onnx.TensorProto.BFLOAT16: NumpyArrayConverter( - np.uint16, onnx.helper.float32_to_bfloat16 + np.uint16, ml_dtypes.bfloat16 ), # FP8 in TensorRT supports negative zeros, no infinities # See https://onnx.ai/onnx/technical/float8.html#papers onnx.TensorProto.FLOAT8E4M3FN: NumpyArrayConverter( - np.uint8, lambda x: onnx.helper.float32_to_float8e4m3(x, fn=True, uz=False) + np.uint8, lambda x: ml_dtypes.float8_e4m3fn(x) ), } diff --git a/tools/onnx-graphsurgeon/setup.py b/tools/onnx-graphsurgeon/setup.py index a0ab7bf6..2eb389d9 100644 --- a/tools/onnx-graphsurgeon/setup.py +++ b/tools/onnx-graphsurgeon/setup.py @@ -29,6 +29,7 @@ def no_publish(): REQUIRED_PACKAGES = [ + "ml_dtypes", "numpy", "onnx>=1.14.0,<=1.16.1", ] diff --git a/tools/onnx-graphsurgeon/tests/test_exporters.py b/tools/onnx-graphsurgeon/tests/test_exporters.py index 5af002d0..eaffb8e2 100644 --- a/tools/onnx-graphsurgeon/tests/test_exporters.py +++ b/tools/onnx-graphsurgeon/tests/test_exporters.py @@ -214,13 +214,13 @@ def test_export_constant_tensor_to_value_info_proto(self): onnx.TensorProto.BFLOAT16, np.uint16, 0.02, - onnx.numpy_helper.bfloat16_to_float32, + np.float32, ), ( onnx.TensorProto.FLOAT8E4M3FN, np.uint8, 0.35, - lambda x, dims: onnx.numpy_helper.float8e4m3_to_float32(x, dims, fn=True, uz=False), + lambda x, dims: np.float32(x), ), ], )