Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions onnx_tf/common/data_type.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from numbers import Number

import numpy as np
from onnx import mapping
from onnx import helper
from onnx import TensorProto
import tensorflow as tf

Expand Down Expand Up @@ -33,8 +33,8 @@ def tf2onnx(dtype):

onnx_dtype = None
try:
onnx_dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(
tf_dype.as_numpy_dtype)]
onnx_dtype = helper.np_dtype_to_tensor_dtype(np.dtype(
tf_dype.as_numpy_dtype))
finally:
if onnx_dtype is None:
common.logger.warning(
Expand All @@ -50,11 +50,11 @@ def onnx2tf(dtype):
# to go directly to tf bfloat16 for now.
if dtype == int(TensorProto.BFLOAT16):
return tf.as_dtype("bfloat16")
return tf.as_dtype(mapping.TENSOR_TYPE_TO_NP_TYPE[_onnx_dtype(dtype)])
return tf.as_dtype(helper.tensor_dtype_to_np_dtype(_onnx_dtype(dtype)))


def onnx2field(dtype):
return mapping.STORAGE_TENSOR_TYPE_TO_FIELD[_onnx_dtype(dtype)]
return helper.tensor_dtype_to_field(_onnx_dtype(dtype))


def _onnx_dtype(dtype):
Expand All @@ -75,7 +75,7 @@ def any_dtype_to_onnx_dtype(np_dtype=None, tf_dtype=None, onnx_dtype=None):
sum(num_type_set))

if np_dtype:
onnx_dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[np_dtype]
onnx_dtype = helper.np_dtype_to_tensor_dtype(np_dtype)
if tf_dtype:
onnx_dtype = tf2onnx(tf_dtype)

Expand Down Expand Up @@ -115,7 +115,7 @@ def is_safe_cast(from_dtype, to_dtype):


def tf_to_np_str(from_type):
return mapping.TENSOR_TYPE_TO_NP_TYPE[int(
return helper.tensor_dtype_to_np_dtype[int(
tf2onnx(from_type))].name if from_type != tf.bfloat16 else 'bfloat16'


Expand Down
4 changes: 2 additions & 2 deletions onnx_tf/handlers/backend/sequence_empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.common import data_type
from onnx import mapping
from onnx import helper


@onnx_op("SequenceEmpty")
class SequenceEmpty(BackendHandler):

@classmethod
def version_11(cls, node, **kwargs):
default_dtype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')]
default_dtype = helper.np_dtype_to_tensor_dtype(np.dtype('float32'))
dtype = data_type.onnx2tf(node.attrs.get("dtype", default_dtype))

ragged = tf.RaggedTensor.from_row_lengths(values=[], row_lengths=[])
Expand Down
8 changes: 5 additions & 3 deletions onnx_tf/pb_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from onnx.helper import make_graph
from onnx.helper import make_tensor
from onnx.helper import make_tensor_value_info
from onnx.helper import mapping
from onnx.helper import tensor_dtype_to_field
from onnx.helper import tensor_dtype_to_storage_tensor_dtype
import tensorflow as tf
from tensorflow.core.framework.attr_value_pb2 import AttrValue
from tensorflow.core.framework.node_def_pb2 import NodeDef
Expand Down Expand Up @@ -426,8 +427,9 @@ def _data_type_caster(cls, protos, data_type_cast_map):
if proto.name in data_type_cast_map:
new_data_type = data_type_cast_map[proto.name]
if type(proto) == TensorProto and proto.data_type != new_data_type:
field = mapping.STORAGE_TENSOR_TYPE_TO_FIELD[
mapping.TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE[proto.data_type]]
field = tensor_dtype_to_field(
tensor_dtype_to_storage_tensor_dtype(proto.data_type)
)
vals = getattr(proto, field)
new_proto = make_tensor(
name=proto.name,
Expand Down