Skip to content

Commit 60dace6

Browse files
convert tensor to ndarray by tensorflow tool
1 parent a88c048 commit 60dace6

File tree

1 file changed

+15
-58
lines changed

1 file changed

+15
-58
lines changed

tf2onnx/utils.py

Lines changed: 15 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from urllib3.util.retry import Retry
2020
import six
2121
import numpy as np
22-
import tensorflow as tf
2322
from tensorflow.core.framework import types_pb2, tensor_pb2
23+
from tensorflow.python.framework import tensor_util
2424
from google.protobuf import text_format
2525
import onnx
2626
from onnx import helper, onnx_pb, defs, numpy_helper
@@ -126,67 +126,24 @@ def split_nodename_and_shape(name):
126126

127127

128128
def tf_to_onnx_tensor(tensor, name=""):
129-
"""
130-
Convert tensorflow tensor to onnx tensor.
131-
Here deal with three types of tensor:
132-
1. normal tensor, e.g., np.array([1,2,3], dtype=DTYPE):
133-
tensor_content: raw data of [1,2,3]
134-
tensor_shape.dim: [3]
135-
DTYPE_val: empty
136-
2. scalar tensor, e.g., np.array(1, dtype=DTYPE):
137-
tensor_content: empty
138-
tensor_shape.dim: [0]
139-
DTYPE_val: 1
140-
3. empty tensor, e.g., np.array([], dtype=DTYPE) and np.array([[]], dtype=DTYPE):
141-
tensor_content: empty
142-
tensor_shape.dim: [0] and [1, 0]
143-
DTYPE_val: empty
144-
"""
145-
new_type = TF_TO_ONNX_DTYPE[tensor.dtype]
146-
tdim = tensor.tensor_shape.dim
147-
dims = [d.size for d in tdim]
148-
is_raw, data = get_tf_tensor_data(tensor)
149-
# empty tensor
150-
if not is_raw and data is None:
151-
np_data = np.array([], dtype=map_onnx_to_numpy_type(new_type)).reshape(dims)
152-
return numpy_helper.from_array(np_data, name=name)
153-
make_sure(data, "tensor data isn't expected to be None or empty")
154-
# scalar tensor
155-
if dims == [0] and not is_raw and len(data) == 1:
156-
return helper.make_tensor(name, new_type, [], data, False)
157-
if not is_raw and len(data) == 1 and np.prod(dims) > 1:
158-
batch_data = np.zeros(dims, dtype=map_onnx_to_numpy_type(new_type))
159-
batch_data.fill(data[0])
160-
return numpy_helper.from_array(batch_data, name=name)
161-
return helper.make_tensor(name, new_type, dims, data, is_raw)
129+
"""Convert tensorflow tensor to onnx tensor."""
130+
np_data = get_tf_tensor_data(tensor)
131+
if np_data.dtype == np.object:
132+
# assume np_data is string, numpy_helper.from_array accepts ndarray,
133+
# in which each item is of str while the whole dtype is of object.
134+
try:
135+
np_data = np_data.astype(np.str).astype(np.object)
136+
except: # pylint: disable=bare-except
137+
raise RuntimeError("Not support type: {}".format(type(np_data.flat[0])))
138+
return numpy_helper.from_array(np_data, name=name)
162139

163140

164141
def get_tf_tensor_data(tensor):
165142
"""Get data from tensor."""
166-
assert isinstance(tensor, tensor_pb2.TensorProto)
167-
is_raw = False
168-
if tensor.tensor_content:
169-
data = tensor.tensor_content
170-
is_raw = True
171-
elif tensor.float_val:
172-
data = tensor.float_val
173-
elif tensor.half_val:
174-
data = tensor.half_val
175-
elif tensor.dcomplex_val:
176-
data = tensor.dcomplex_val
177-
elif tensor.int_val:
178-
data = tensor.int_val
179-
elif tensor.int64_val:
180-
data = tensor.int64_val
181-
elif tensor.bool_val:
182-
data = tensor.bool_val
183-
elif tensor.string_val:
184-
data = tensor.string_val
185-
elif tensor.dtype in [tf.int32, tf.int64, tf.float32, tf.float16]:
186-
data = None
187-
else:
188-
raise ValueError('tensor data not supported')
189-
return [is_raw, data]
143+
make_sure(isinstance(tensor, tensor_pb2.TensorProto), "Require TensorProto")
144+
np_data = tensor_util.MakeNdarray(tensor)
145+
make_sure(isinstance(np_data, np.ndarray), "{} isn't ndarray".format(np_data))
146+
return np_data
190147

191148

192149
def get_shape(node):

0 commit comments

Comments
 (0)