Skip to content

Commit 22c8495

Browse files
committed
Replace onnx.mapping reference with helper
1 parent ee0c5e5 commit 22c8495

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

onnx_tf/common/data_type.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from numbers import Number
22

33
import numpy as np
4-
from onnx import mapping
4+
from onnx import helper
55
from onnx import TensorProto
66
import tensorflow as tf
77

@@ -33,8 +33,8 @@ def tf2onnx(dtype):
3333

3434
onnx_dtype = None
3535
try:
36-
onnx_dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(
37-
tf_dype.as_numpy_dtype)]
36+
onnx_dtype = helper.np_dtype_to_tensor_dtype(np.dtype(
37+
tf_dype.as_numpy_dtype))
3838
finally:
3939
if onnx_dtype is None:
4040
common.logger.warning(
@@ -50,11 +50,11 @@ def onnx2tf(dtype):
5050
# to go directly to tf bfloat16 for now.
5151
if dtype == int(TensorProto.BFLOAT16):
5252
return tf.as_dtype("bfloat16")
53-
return tf.as_dtype(mapping.TENSOR_TYPE_TO_NP_TYPE[_onnx_dtype(dtype)])
53+
return tf.as_dtype(helper.tensor_dtype_to_np_dtype(_onnx_dtype(dtype)))
5454

5555

5656
def onnx2field(dtype):
57-
return mapping.STORAGE_TENSOR_TYPE_TO_FIELD[_onnx_dtype(dtype)]
57+
return helper.tensor_dtype_to_field(_onnx_dtype(dtype))
5858

5959

6060
def _onnx_dtype(dtype):
@@ -75,7 +75,7 @@ def any_dtype_to_onnx_dtype(np_dtype=None, tf_dtype=None, onnx_dtype=None):
7575
sum(num_type_set))
7676

7777
if np_dtype:
78-
onnx_dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[np_dtype]
78+
onnx_dtype = helper.np_dtype_to_tensor_dtype(np_dtype)
7979
if tf_dtype:
8080
onnx_dtype = tf2onnx(tf_dtype)
8181

@@ -115,7 +115,7 @@ def is_safe_cast(from_dtype, to_dtype):
115115

116116

117117
def tf_to_np_str(from_type):
118-
return mapping.TENSOR_TYPE_TO_NP_TYPE[int(
118+
return helper.tensor_dtype_to_np_dtype[int(
119119
tf2onnx(from_type))].name if from_type != tf.bfloat16 else 'bfloat16'
120120

121121

onnx_tf/handlers/backend/sequence_empty.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
from onnx_tf.handlers.backend_handler import BackendHandler
55
from onnx_tf.handlers.handler import onnx_op
66
from onnx_tf.common import data_type
7-
from onnx import mapping
7+
from onnx import helper
88

99

1010
@onnx_op("SequenceEmpty")
1111
class SequenceEmpty(BackendHandler):
1212

1313
@classmethod
1414
def version_11(cls, node, **kwargs):
15-
default_dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')]
15+
default_dtype = helper.np_dtype_to_tensor_dtype(np.dtype('float32'))
1616
dtype = data_type.onnx2tf(node.attrs.get("dtype", default_dtype))
1717

1818
ragged = tf.RaggedTensor.from_row_lengths(values=[], row_lengths=[])

onnx_tf/pb_wrapper.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from onnx.helper import make_graph
1010
from onnx.helper import make_tensor
1111
from onnx.helper import make_tensor_value_info
12-
from onnx.helper import mapping
12+
from onnx.helper import tensor_dtype_to_field
13+
from onnx.helper import tensor_dtype_to_storage_tensor_dtype
1314
import tensorflow as tf
1415
from tensorflow.core.framework.attr_value_pb2 import AttrValue
1516
from tensorflow.core.framework.node_def_pb2 import NodeDef
@@ -426,8 +427,9 @@ def _data_type_caster(cls, protos, data_type_cast_map):
426427
if proto.name in data_type_cast_map:
427428
new_data_type = data_type_cast_map[proto.name]
428429
if type(proto) == TensorProto and proto.data_type != new_data_type:
429-
field = mapping.STORAGE_TENSOR_TYPE_TO_FIELD[
430-
mapping.TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE[proto.data_type]]
430+
field = tensor_dtype_to_field(
431+
tensor_dtype_to_storage_tensor_dtype(proto.data_type)
432+
)
431433
vals = getattr(proto, field)
432434
new_proto = make_tensor(
433435
name=proto.name,

0 commit comments

Comments
 (0)