Commit 811a3a3
Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into profile
2 parents b3b08f0 + 62f1e70

File tree: 11 files changed, +127 additions, -46 deletions

README.md
Lines changed: 1 addition & 1 deletion

@@ -293,7 +293,7 @@ tf2onnx.tfonnx.process_tf_graph(tf_graph,
         onnx graph
     """
 ```
-For example in [examples/call_coverter_via_python.py]():
+For example in [examples/call_converter_via_python.py]():
 ```
 import tensorflow as tf
 import tf2onnx

tests/test_backend.py
Lines changed: 13 additions & 0 deletions

@@ -2044,6 +2044,19 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         _ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

+    @check_tf_min_version("2.0")
+    @check_opset_min_version(13, "quantize_and_dequantize")
+    def test_qdq_per_channel_signed_input(self):
+        x_shape = [3, 3, 2]
+        x_val = np.arange(-np.prod(x_shape)/2, np.prod(x_shape)/2).astype("float32").reshape(x_shape)
+        def func(x):
+            x_ = quantize_and_dequantize(x, np.array([-1.72, -3.89]).astype(np.float32), \
+                                         np.array([5.12, 2.36]).astype(np.float32), \
+                                         signed_input=True, narrow_range=False, \
+                                         range_given=True, axis=-1)
+            return tf.identity(x_, name=_TFOUTPUT)
+        _ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
+
     @skip_caffe2_backend()
     @check_opset_min_version(7, "resize_nearest_neighbor")
     def test_resize_nearest_neighbor(self):

tests/test_internals.py
Lines changed: 2 additions & 2 deletions

@@ -226,7 +226,7 @@ def test_node_attr_onnx(self):
         g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
         n1 = g.get_node_by_name("n1")
         self.assertTrue("my_attr" in n1.attr)
-        self.assertTrue("my_attr" not in n1.attr_onnx)
+        self.assertTrue("my_attr" not in n1.get_onnx_attrs())

         n1 = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", domain="my_domain", my_attr="my_attr")
         graph_proto = helper.make_graph(
@@ -240,7 +240,7 @@ def test_node_attr_onnx(self):
         g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
         n1 = g.get_node_by_name("n1")
         self.assertTrue("my_attr" in n1.attr)
-        self.assertTrue("my_attr" in n1.attr_onnx)
+        self.assertTrue("my_attr" in n1.get_onnx_attrs())

     def test_tensor_data(self):
         tensors = {

tf2onnx/graph.py
Lines changed: 45 additions & 14 deletions

@@ -28,6 +28,13 @@
 # todo(pengwa): remove protected-access later
 # pylint: disable=broad-except,protected-access

+class ExternalTensorStorage():
+    """Passed into graph and node methods to accumulate tensors to save externally"""
+    def __init__(self):
+        self.name_to_tensor_data = {}
+        self.name_counter = 0
+        self.external_tensor_size_threshold = 1024
+        self.node_to_modified_value_attr = {}

 class Node(object):
     """A Node - wrapper around onnx nodes that we use for graph manipulations."""
@@ -93,16 +100,40 @@ def inputs(self):
     def attr(self):
         return self._attr

-    @property
-    def attr_onnx(self):
-        """Return onnx valid attributes"""
+    def get_value_attr(self, external_tensor_storage=None):
+        """Return onnx attr for value property of node.
+        Attr is modified to point to external tensor data stored in external_tensor_storage, if included.
+        """
+        a = self._attr["value"]
+        if external_tensor_storage is not None and self in external_tensor_storage.node_to_modified_value_attr:
+            return external_tensor_storage.node_to_modified_value_attr[self]
+        if external_tensor_storage is None or a.type != AttributeProto.TENSOR:
+            return a
+        if np.product(a.t.dims) > external_tensor_storage.external_tensor_size_threshold:
+            a = copy.copy(a)
+            tensor_name = self.name + "_" + str(external_tensor_storage.name_counter)
+            external_tensor_storage.name_counter += 1
+            external_tensor_storage.name_to_tensor_data[tensor_name] = a.t.raw_data
+            external_tensor_storage.node_to_modified_value_attr[self] = a
+            a.t.raw_data = b'__EXTERNAL'
+            location = a.t.external_data.add()
+            location.key = "location"
+            location.value = tensor_name
+            a.t.data_location = TensorProto.EXTERNAL
+        return a
+
+    def get_onnx_attrs(self, external_tensor_storage=None):
+        """Return onnx valid attributes.
+        Attrs point to external tensor data stored in external_tensor_storage, if included."""
         schema = get_schema(self.type, self.graph.opset, self.domain)
         if schema is None and not (self.is_const() or self.is_graph_input()):
             logger.debug("Node %s uses non-stardard onnx op <%s, %s>, skip attribute check",
                          self.name, self.domain, self.type)
         onnx_attrs = {}
         for a in self._attr.values():
-            if schema is None or schema.has_attribute(a.name):
+            if a.name == "value":
+                onnx_attrs[a.name] = self.get_value_attr(external_tensor_storage)
+            elif schema is None or schema.has_attribute(a.name):
                 onnx_attrs[a.name] = a
         return onnx_attrs

@@ -333,7 +364,7 @@ def set_body_graph_as_attr(self, attr_name, graph):
         self.graph.contained_graphs[self.name].update({attr_name: graph})
         graph.parent_graph = self.graph

-    def update_proto(self):
+    def update_proto(self, external_tensor_storage=None):
         """Update protobuf from internal structure."""
         nodes = list(self._op.input)
         for node in nodes:
@@ -351,10 +382,10 @@ def update_proto(self):
         attr_graphs = self.get_body_graphs()
         if attr_graphs:
             for attr_name, sub_graph in attr_graphs.items():
-                graph_proto = sub_graph.make_graph("graph for " + self.name + " " + attr_name)
+                graph_proto = sub_graph.make_graph("graph for " + self.name + " " + attr_name, external_tensor_storage)
                 self.set_attr(attr_name, graph_proto)

-        attr = list(self.attr_onnx.values())
+        attr = list(self.get_onnx_attrs(external_tensor_storage).values())
         if attr:
             self._op.attribute.extend(attr)

@@ -748,10 +779,10 @@ def update_node_shape_dtype(self, node, override=False):
             self.set_shape(output, shape)
             logger.debug("Set shape of [%s] to %s", output, shape)

-    def update_proto(self):
+    def update_proto(self, external_tensor_storage=None):
         """Update the onnx protobuf from out internal Node structure."""
         for node in self._nodes:
-            node.update_proto()
+            node.update_proto(external_tensor_storage)

     def get_nodes(self):
         """Get node list."""
@@ -968,7 +999,7 @@ def _get_unvisited_child(g, node, not_visited):
         ret = [x for _, x in sorted(zip(label, ops))]
         self.reset_nodes(ret)

-    def make_graph(self, doc, graph_name=None):
+    def make_graph(self, doc, graph_name=None, external_tensor_storage=None):
         """
         Create GraphProto for onnx from internal graph.
         Args:
@@ -978,7 +1009,7 @@ def make_graph(self, doc, graph_name=None):
         graph_name = graph_name or self.graph_name
         self.delete_unused_nodes(self.outputs)
         self.topological_sort(self.get_nodes())
-        self.update_proto()
+        self.update_proto(external_tensor_storage)

         # TODO: we'd want to do something like this so that transpose optimizer is active
         # for all (unit) tests
@@ -1021,7 +1052,7 @@ def make_graph(self, doc, graph_name=None):
                 # not to use numpy_helper.from_array to create a new tensor
                 # because sometimes onnx will have a bug that only check the tensor data in specific field
                 # such as at upsample it only checks the float_data field.
-                t = op.get_attr("value")
+                t = op.get_value_attr(external_tensor_storage)
                 tensor = helper.get_attribute_value(t)
                 tensor.name = op.output[0]
                 initializers.append(tensor)
@@ -1050,14 +1081,14 @@ def make_graph(self, doc, graph_name=None):

         return graph

-    def make_model(self, graph_doc, optimize=False, graph_name="tf2onnx", **kwargs):
+    def make_model(self, graph_doc, optimize=False, graph_name="tf2onnx", external_tensor_storage=None, **kwargs):
         """
         Create final ModelProto for onnx from internal graph.
         Args:
             optimize: optimize graph via onnx
             doc: text for doc string of the model
         """
-        graph = self.make_graph(graph_doc, graph_name)
+        graph = self.make_graph(graph_doc, graph_name, external_tensor_storage)

         if "producer_name" not in kwargs:
             kwargs = {"producer_name": "tf2onnx",

tf2onnx/onnx_opset/math.py
Lines changed: 2 additions & 2 deletions

@@ -32,7 +32,7 @@ class RealDiv(common.BroadcastOp):
     pass


-@tf_op(["LeakyRelu", "LogSoftmax", "Softplus", "Softsign"])
+@tf_op(["LeakyRelu", "Softplus", "Softsign"])
 class DirectOpSinceOpset1:
     @classmethod
     def version_1(cls, ctx, node, **kwargs):
@@ -185,7 +185,7 @@ def version_8(cls, ctx, node, **kwargs):
     def version_12(cls, ctx, node, **kwargs):
         node.type = 'Clip' # clip supports all types now

-@tf_op("Softmax")
+@tf_op(["LogSoftmax", "Softmax"])
 class Softmax:
     @classmethod
     def version_1(cls, ctx, node, **kwargs):
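
Routing LogSoftmax through the Softmax handler instead of the direct 1:1 op mapping is most plausibly about axis handling (my reading; the diff gives no rationale): TF computes LogSoftmax over the last axis, while ONNX LogSoftmax before opset 13 coerces the input to 2D with a default axis of 1, so a direct mapping silently changes the result for inputs of rank greater than 2. A numpy sketch of the mismatch:

```python
import numpy as np

def log_softmax(x, axis):
    # Numerically stable log-softmax along one axis.
    x_max = x.max(axis=axis, keepdims=True)
    return x - x_max - np.log(np.exp(x - x_max).sum(axis=axis, keepdims=True))

x = np.random.rand(2, 3, 4).astype(np.float32)
tf_style = log_softmax(x, axis=-1)                  # TF default: last axis
onnx_pre13 = log_softmax(x.reshape(2, 12), axis=1)  # ONNX <13 default: coerce to 2D, axis=1
print(np.allclose(tf_style.reshape(2, 12), onnx_pre13))  # False for rank-3 input
```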

tf2onnx/optimizer/transpose_optimizer.py
Lines changed: 2 additions & 2 deletions

@@ -282,7 +282,7 @@ def _handle_nhwc_tranpose(self, trans):
             return False
         # move transpose into branches to let Transposes can be "handled" in each branch
         for n in out_nodes:
-            branch_trans = n.graph.make_node("Transpose", [trans.input[0]], attr=trans.attr_onnx)
+            branch_trans = n.graph.make_node("Transpose", [trans.input[0]], attr=trans.get_onnx_attrs())
             n.graph.replace_input(n, trans.output[0], branch_trans.output[0])

         self._g.remove_node(trans.name)
@@ -407,7 +407,7 @@ def _add_handler(self, trans, node):
             target_node.set_tensor_value(target_val)

             conv_inputs = [t_p.input[0], t_p.input[1], node.input[1]]
-            conv_node = self._g.make_node(t_p.type, conv_inputs, attr=t_p.attr_onnx)
+            conv_node = self._g.make_node(t_p.type, conv_inputs, attr=t_p.get_onnx_attrs())
             ops = self._g.get_nodes()
             self._g.replace_input(trans, trans.input[0], utils.port_name(conv_node.name), 0)
             self._g.replace_all_inputs(ops, node.output[0], trans.output[0])

tf2onnx/rewriter/quantization_ops_rewriter.py
Lines changed: 35 additions & 17 deletions

@@ -2,7 +2,7 @@
 # Licensed under the MIT license.

 """
-tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV3 op
+tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV2|QuantizeAndDequantizeV3 op
 """

 import numpy as np
@@ -32,47 +32,65 @@ def create_qdq_nodes(g, match_results):
         if not signed_input:
             min_quantized, max_quantized = [0, 255]

+        # Get axis attribute for per channel implementation.
+        if 'axis' in qdq_node.attr:
+            axis = qdq_node.attr['axis'].i
+
         # Get the min and max value of the inputs to QDQ op
         min_value = extract_numpy_array(qdq_node.inputs[1])
         max_value = extract_numpy_array(qdq_node.inputs[2])

-        # Calculate scales from the min and max values
-        scale_from_min_side = min_quantized/min_value if min_quantized*min_value > 0 else max_quantized
-        scale_from_max_side = max_quantized/max_value if max_quantized*max_value > 0 else max_quantized
-
-        if scale_from_min_side < scale_from_max_side:
-            scale = scale_from_min_side
-        else:
-            scale = scale_from_max_side
-
-        utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero")
-
-        if signed_input:
-            zero_point = np.int8(0)
+        num_channels = min_value.shape[0]
+        scales = np.zeros(num_channels, dtype=np.float32)
+        zero_point_dtype = np.int8 if signed_input else np.uint8
+        zero_point = np.zeros(num_channels, dtype=zero_point_dtype)
+
+        for i in range(num_channels):
+            # Calculate scales from the min and max values
+            scale_from_min_side = min_quantized/min_value[i] if min_quantized*min_value[i] > 0 else max_quantized
+            scale_from_max_side = max_quantized/max_value[i] if max_quantized*max_value[i] > 0 else max_quantized
+
+            if scale_from_min_side < scale_from_max_side:
+                scale = scale_from_min_side
+            else:
+                scale = scale_from_max_side
+
+            utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero")
+            scales[i] = np.float32(scale)
+
+        # Set scalars for scale and zero point for per layer quantization
+        if num_channels == 1:
+            scales = scales[0]
+            zero_point = zero_point[0]
+            attrs = {}
         else:
-            zero_point = np.uint8(0)
+            utils.make_sure(axis, "Axis must be specified for per channel quantization")
+            attrs = {'axis': axis}

         # Split it into QuantizeLinear and DequantizeLinear and remove the QDQ node reference
-        y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=1/scale)
+        inverse_scale = (1/scales).astype(np.float32)
+        y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=inverse_scale)
         y_zero_point = g.make_const(name=utils.make_name("y_zero_point"), np_val=zero_point)
         quant_node = g.make_node(op_type="QuantizeLinear",
                                  inputs=[qdq_node.input[0], y_quant_scale.output[0],
                                          y_zero_point.output[0]],
                                  shapes=[qdq_node_output_shape],
+                                 attr=attrs,
                                  dtypes=[qdq_node_output_dtype],
                                  name=utils.make_name("QuantLinearNode"))

         g.set_shape(quant_node.output[0], qdq_node_output_shape)

         g.remove_node(qdq_node.name)

-        y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=1/scale)
+        y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=inverse_scale)
         y_inv_zero_point = g.make_const(name=utils.make_name("y_inv_zero_point"), np_val=zero_point)
         dequant_node = g.make_node(op_type="DequantizeLinear",
                                    inputs=[quant_node.output[0], y_dequant_scale.output[0],
                                            y_inv_zero_point.output[0]],
                                    outputs=[qdq_node.output[0]],
                                    shapes=[qdq_node_output_shape],
+                                   attr=attrs,
                                    dtypes=[qdq_node_output_dtype],
                                    name=utils.make_name("DequantLinearNode"))
         g.set_shape(dequant_node.output[0], qdq_node_output_shape)
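
To make the per-channel arithmetic concrete, here is the scale computation from the rewriter replayed on the ranges used by the new test_qdq_per_channel_signed_input test. The signed quantized range [-128, 127] is an assumption carried over from the signed-input branch earlier in this function, which the hunk does not show:

```python
import numpy as np

min_quantized, max_quantized = -128, 127  # assumed signed, non-narrow range
min_value = np.array([-1.72, -3.89], dtype=np.float32)  # per-channel mins from the test
max_value = np.array([5.12, 2.36], dtype=np.float32)    # per-channel maxes from the test

scales = np.zeros(min_value.shape[0], dtype=np.float32)
for i in range(len(scales)):
    # Same per-element logic as the loop in the rewriter above.
    scale_from_min_side = min_quantized / min_value[i] if min_quantized * min_value[i] > 0 else max_quantized
    scale_from_max_side = max_quantized / max_value[i] if max_quantized * max_value[i] > 0 else max_quantized
    scales[i] = min(scale_from_min_side, scale_from_max_side)

print(scales)                           # approx [24.80, 32.90]: quantized units per real unit
print((1 / scales).astype(np.float32))  # the step sizes stored in y_quant_scale / y_dequant_scale
```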

tf2onnx/schemas.py
Lines changed: 1 addition & 1 deletion

@@ -136,7 +136,7 @@ def build_onnx_op(node):
             copied_sub_graph = copy.deepcopy(sub_graph)
             graph_proto = copied_sub_graph.make_graph("graph for " + node.name + " " + attr_name)
             attr.append(helper.make_attribute(attr_name, graph_proto))
-    attr.extend(node.attr_onnx.values())
+    attr.extend(node.get_onnx_attrs().values())
     if attr:
         onnx_node.attribute.extend(attr)
     return onnx_node

tf2onnx/tf_utils.py
Lines changed: 24 additions & 5 deletions

@@ -124,8 +124,24 @@ def get_tf_node_attr(node, name):
 def get_tf_version():
     return LooseVersion(tf.__version__)

-
-def tflist_to_onnx(g, shape_override):
+def compress_graph_def(graph_def):
+    """
+    Remove large const values from graph. This lets us import the graph and run shape inference without TF crashing.
+    """
+    node_defs = list(graph_def.node)
+    const_node_values = {}
+    for node_def in node_defs:
+        if node_def.op == 'Const':
+            tensor = node_def.attr["value"].tensor
+            # Small constants are sometimes used to store shape information and must be maintained
+            if len(tensor.tensor_content) > 1000:
+                make_sure(node_def.name not in const_node_values, "Two nodes in graph have same name %s", node_def.name)
+                const_node_values[node_def.name] = tensor.tensor_content
+                tensor.tensor_content = b''
+    return const_node_values
+
+
+def tflist_to_onnx(g, shape_override, const_node_values=None):
     """
     Convert the tf-node list into an onnx graph with minimal rewrites so
     we can use the onnx graph as intermediate graph.
@@ -193,7 +209,10 @@ def tflist_to_onnx(g, shape_override):
                 attr[a] = nattr.name
                 functions[nattr.name] = input_shapes
             elif a == "value":
-                onnx_tensor = tf_to_onnx_tensor(get_tf_node_attr(node, a), name=port_name(node.name))
+                tensor = get_tf_node_attr(node, a)
+                if const_node_values and node.name in const_node_values:
+                    tensor.tensor_content = const_node_values[node.name]
+                onnx_tensor = tf_to_onnx_tensor(tensor, name=port_name(node.name))
                 attr[a] = onnx_tensor
             elif a == "DstT":
                 attr["to"] = map_tf_dtype(get_tf_node_attr(node, "DstT"))
@@ -217,8 +236,8 @@ def tflist_to_onnx(g, shape_override):
     return onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions


-def tensorflow_to_onnx(graph, shape_override):
+def tensorflow_to_onnx(graph, shape_override, const_node_values=None):
     """
     Load tensorflow graph and do a conversion.
     """
-    return tflist_to_onnx(graph, shape_override)
+    return tflist_to_onnx(graph, shape_override, const_node_values)
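
compress_graph_def and the new const_node_values parameter form a round trip: large Const payloads are stripped before the graph_def is imported into TF (so import and shape inference do not choke on huge constants), then re-attached when each Const's "value" attribute is converted to an ONNX tensor. A rough usage sketch (only compress_graph_def and the const_node_values parameter come from this diff; the driver code and file name are assumptions):

```python
import tensorflow as tf
from tensorflow.core.framework import graph_pb2

graph_def = graph_pb2.GraphDef()
with open("frozen_model.pb", "rb") as f:  # assumed input file
    graph_def.ParseFromString(f.read())

# Strip tensor_content from big Const nodes, keeping the bytes keyed by node name.
const_node_values = compress_graph_def(graph_def)

# Import the now-lightweight graph; TF can run shape inference safely.
with tf.Graph().as_default() as tf_graph:
    tf.import_graph_def(graph_def, name="")

# Conversion re-injects the saved bytes when it reaches each Const's "value" attr.
results = tensorflow_to_onnx(tf_graph, {}, const_node_values=const_node_values)
```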
