Skip to content

Commit fbb76a3

Browse files
authored
Merge pull request #4653 from m-gupta/main
Issue #4635: Use ml_dtypes to fix onnx_graphsurgeon import error
2 parents a9a797d + d53588c commit fbb76a3

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

tools/onnx-graphsurgeon/onnx_graphsurgeon/exporters/onnx_exporter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from typing import List, Sequence, Union
2020

21+
import ml_dtypes
2122
import numpy as np
2223
import onnx
2324
import onnx.numpy_helper
@@ -131,12 +132,12 @@ def __call__(self, arr):
131132

132133
_NUMPY_ARRAY_CONVERTERS = {
133134
onnx.TensorProto.BFLOAT16: NumpyArrayConverter(
134-
np.uint16, onnx.helper.float32_to_bfloat16
135+
np.uint16, ml_dtypes.bfloat16
135136
),
136137
# FP8 in TensorRT supports negative zeros, no infinities
137138
# See https://onnx.ai/onnx/technical/float8.html#papers
138139
onnx.TensorProto.FLOAT8E4M3FN: NumpyArrayConverter(
139-
np.uint8, lambda x: onnx.helper.float32_to_float8e4m3(x, fn=True, uz=False)
140+
np.uint8, lambda x: ml_dtypes.float8_e4m3fn(x)
140141
),
141142
}
142143

tools/onnx-graphsurgeon/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def no_publish():
2929

3030

3131
REQUIRED_PACKAGES = [
32+
"ml_dtypes",
3233
"numpy",
3334
"onnx>=1.14.0,<=1.16.1",
3435
]

tools/onnx-graphsurgeon/tests/test_exporters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,13 @@ def test_export_constant_tensor_to_value_info_proto(self):
214214
onnx.TensorProto.BFLOAT16,
215215
np.uint16,
216216
0.02,
217-
onnx.numpy_helper.bfloat16_to_float32,
217+
np.float32,
218218
),
219219
(
220220
onnx.TensorProto.FLOAT8E4M3FN,
221221
np.uint8,
222222
0.35,
223-
lambda x, dims: onnx.numpy_helper.float8e4m3_to_float32(x, dims, fn=True, uz=False),
223+
lambda x, dims: np.float32(x),
224224
),
225225
],
226226
)

0 commit comments

Comments
 (0)