Skip to content

Commit 41fbaa9

Browse files
authored
Merge pull request #3 from onnx/master
update
2 parents 7c5340c + 4d0fced commit 41fbaa9

File tree

3 files changed

+28
-7
lines changed

3 files changed

+28
-7
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ tf2onnx - Convert TensorFlow models to ONNX.
33

44
| Build Type | OS | Python | Tensorflow | Onnx opset | Status |
55
| --- | --- | --- | --- | --- | --- |
6-
| Unit Test - Basic | Linux, MacOS<sup>\*</sup>, Windows<sup>\*</sup> | 3.5, 3.6 | 1.5-1.13 | 7-10 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=master) |
7-
| Unit Test - Full | Linux, MacOS, Windows | 3.5, 3.6, 3.7 | 1.5-1.13 | 7-10 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=master) | |
6+
| Unit Test - Basic | Linux, MacOS<sup>\*</sup>, Windows<sup>\*</sup> | 3.5, 3.6 | 1.5-1.14 | 7-10 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=master) |
7+
| Unit Test - Full | Linux, MacOS, Windows | 3.5, 3.6, 3.7 | 1.5-1.14 | 7-10 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=master) | |
88

9-
<a name="build_status_footnote">\*</a> Only test on python3.6, TF1.13.
9+
<a name="build_status_footnote">\*</a> Only test on python3.6, TF1.14.
1010

1111
# Supported ONNX version
1212
tensorflow-onnx will use the ONNX version installed on your system and installs the latest ONNX version if none is found.
@@ -51,7 +51,7 @@ For pytorch/caffe2, follow the instructions here:
5151
We tested with pytorch/caffe2 and onnxruntime and unit tests are passing for those.
5252

5353
## Supported Tensorflow and Python Versions
54-
We are testing with tensorflow 1.5-1.13 and anaconda **3.5,3.6,3.7**.
54+
We are testing with tensorflow 1.5-1.14 and anaconda **3.5,3.6,3.7**.
5555

5656
# Installation
5757
## From pypi

tf2onnx/optimizer/merge_duplicated_nodes_optimizer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,18 +61,24 @@ def _del_nodes_if_duplicated(self, nodes_group, graph):
6161
unprocessed_node = []
6262
nodes_to_process = [nodes_group[0]]
6363
for node in nodes_group[1:]:
64-
if self._have_equal_attr(node, nodes_to_process[0]):
64+
if self._have_equal_attr(node, nodes_to_process[0], graph):
6565
nodes_to_process.append(node)
6666
else:
6767
unprocessed_node.append(node)
6868

6969
self._merge_nodes_that_are_duplicated(nodes_to_process, graph)
7070
nodes_group = unprocessed_node
7171

72-
def _have_equal_attr(self, node_1, node_2):
72+
def _have_equal_attr(self, node_1, node_2, graph):
7373
if node_1.attr == node_2.attr:
7474
return True
7575
if node_1.is_const() and node_2.is_const():
76+
# get_tensor_value is costly so that we check their shape first
77+
shape_1 = graph.get_shape(node_1.output[0])
78+
shape_2 = graph.get_shape(node_2.output[0])
79+
if shape_1 is not None and shape_2 is not None and \
80+
shape_1 != shape_2:
81+
return False
7682
const_1 = node_1.get_tensor_value(as_list=False)
7783
const_2 = node_2.get_tensor_value(as_list=False)
7884
if const_1.dtype == const_2.dtype and \

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def _handle_node_having_branches(self, node):
199199
# otherwise, it would impact their other output nodes
200200
if self._nodes_has_single_consumer_node(node.inputs):
201201
self._create_transpose_pairs_after_node(node)
202-
input_transposes = node.inputs
202+
input_transposes = set(node.inputs)
203203
for n in input_transposes:
204204
n_input = n.input[0]
205205
utils.make_sure(len(n.output) == 1, "only expect single output")
@@ -371,11 +371,26 @@ def _add_handler(self, trans, node):
371371
# if Conv or ConvTranspose's bias input is not set, then we set, otherwise, we don't set
372372
# todo: maybe we can add already set bias with the input??? try later
373373

374+
if not self._nodes_has_single_consumer_node([t_p]):
375+
self.logger.debug("Conv does not have single consumer, can not merge Conv and Add")
376+
return self._handle_node_having_branches(node)
377+
378+
if not self._nodes_has_single_consumer_node([trans]):
379+
self.logger.debug("input transpose does not have single consumer, skipping...")
380+
return False
381+
374382
target_node = node.inputs[1]
375383
numpy_val = target_node.get_tensor_value(as_list=False)
376384
# Optional 1D bias to be added to the convolution, has size of M
377385
if len(numpy_val.shape) - numpy_val.shape.count(1) > 1:
378386
return self._handle_node_having_branches(node)
387+
388+
rank = len(numpy_val.shape)
389+
utils.make_sure(rank in (1, 4), "only support bias rank = 4 or 1")
390+
# to make rank = 4
391+
if rank == 1:
392+
numpy_val = numpy_val.reshape((1, 1, 1, numpy_val.shape[0]))
393+
379394
transposed_val = np.transpose(numpy_val, (0, 3, 1, 2))
380395
target_node.set_tensor_value(transposed_val)
381396

0 commit comments

Comments
 (0)