|
19 | 19 | from urllib3.util.retry import Retry
|
20 | 20 | import six
|
21 | 21 | import numpy as np
|
22 |
| -import tensorflow as tf |
23 | 22 | from tensorflow.core.framework import types_pb2, tensor_pb2
|
| 23 | +from tensorflow.python.framework import tensor_util |
24 | 24 | from google.protobuf import text_format
|
25 | 25 | import onnx
|
26 | 26 | from onnx import helper, onnx_pb, defs, numpy_helper
|
@@ -126,67 +126,24 @@ def split_nodename_and_shape(name):
|
126 | 126 |
|
127 | 127 |
|
128 | 128 | def tf_to_onnx_tensor(tensor, name=""):
|
129 |
| - """ |
130 |
| - Convert tensorflow tensor to onnx tensor. |
131 |
| - Here deal with three types of tensor: |
132 |
| - 1. normal tensor, e.g., np.array([1,2,3], dtype=DTYPE): |
133 |
| - tensor_content: raw data of [1,2,3] |
134 |
| - tensor_shape.dim: [3] |
135 |
| - DTYPE_val: empty |
136 |
| - 2. scalar tensor, e.g., np.array(1, dtype=DTYPE): |
137 |
| - tensor_content: empty |
138 |
| - tensor_shape.dim: [0] |
139 |
| - DTYPE_val: 1 |
140 |
| - 3. empty tensor, e.g., np.array([], dtype=DTYPE) and np.array([[]], dtype=DTYPE): |
141 |
| - tensor_content: empty |
142 |
| - tensor_shape.dim: [0] and [1, 0] |
143 |
| - DTYPE_val: empty |
144 |
| - """ |
145 |
| - new_type = TF_TO_ONNX_DTYPE[tensor.dtype] |
146 |
| - tdim = tensor.tensor_shape.dim |
147 |
| - dims = [d.size for d in tdim] |
148 |
| - is_raw, data = get_tf_tensor_data(tensor) |
149 |
| - # empty tensor |
150 |
| - if not is_raw and data is None: |
151 |
| - np_data = np.array([], dtype=map_onnx_to_numpy_type(new_type)).reshape(dims) |
152 |
| - return numpy_helper.from_array(np_data, name=name) |
153 |
| - make_sure(data, "tensor data isn't expected to be None or empty") |
154 |
| - # scalar tensor |
155 |
| - if dims == [0] and not is_raw and len(data) == 1: |
156 |
| - return helper.make_tensor(name, new_type, [], data, False) |
157 |
| - if not is_raw and len(data) == 1 and np.prod(dims) > 1: |
158 |
| - batch_data = np.zeros(dims, dtype=map_onnx_to_numpy_type(new_type)) |
159 |
| - batch_data.fill(data[0]) |
160 |
| - return numpy_helper.from_array(batch_data, name=name) |
161 |
| - return helper.make_tensor(name, new_type, dims, data, is_raw) |
| 129 | + """Convert tensorflow tensor to onnx tensor.""" |
| 130 | + np_data = get_tf_tensor_data(tensor) |
| 131 | + if np_data.dtype == np.object: |
| 132 | + # assume np_data is string, numpy_helper.from_array accepts ndarray, |
| 133 | + # in which each item is of str while the whole dtype is of object. |
| 134 | + try: |
| 135 | + np_data = np_data.astype(np.str).astype(np.object) |
| 136 | + except: # pylint: disable=bare-except |
| 137 | + raise RuntimeError("Not support type: {}".format(type(np_data.flat[0]))) |
| 138 | + return numpy_helper.from_array(np_data, name=name) |
162 | 139 |
|
163 | 140 |
|
164 | 141 | def get_tf_tensor_data(tensor):
|
165 | 142 | """Get data from tensor."""
|
166 |
| - assert isinstance(tensor, tensor_pb2.TensorProto) |
167 |
| - is_raw = False |
168 |
| - if tensor.tensor_content: |
169 |
| - data = tensor.tensor_content |
170 |
| - is_raw = True |
171 |
| - elif tensor.float_val: |
172 |
| - data = tensor.float_val |
173 |
| - elif tensor.half_val: |
174 |
| - data = tensor.half_val |
175 |
| - elif tensor.dcomplex_val: |
176 |
| - data = tensor.dcomplex_val |
177 |
| - elif tensor.int_val: |
178 |
| - data = tensor.int_val |
179 |
| - elif tensor.int64_val: |
180 |
| - data = tensor.int64_val |
181 |
| - elif tensor.bool_val: |
182 |
| - data = tensor.bool_val |
183 |
| - elif tensor.string_val: |
184 |
| - data = tensor.string_val |
185 |
| - elif tensor.dtype in [tf.int32, tf.int64, tf.float32, tf.float16]: |
186 |
| - data = None |
187 |
| - else: |
188 |
| - raise ValueError('tensor data not supported') |
189 |
| - return [is_raw, data] |
| 143 | + make_sure(isinstance(tensor, tensor_pb2.TensorProto), "Require TensorProto") |
| 144 | + np_data = tensor_util.MakeNdarray(tensor) |
| 145 | + make_sure(isinstance(np_data, np.ndarray), "{} isn't ndarray".format(np_data)) |
| 146 | + return np_data |
190 | 147 |
|
191 | 148 |
|
192 | 149 | def get_shape(node):
|
|
0 commit comments