Skip to content

Commit 61ff34f

Browse files
authored
Merge pull request #524 from onnx/gs/opt-relu6
optimize relu6
2 parents 0f880fc + 965e988 commit 61ff34f

File tree

4 files changed

+16
-17
lines changed

4 files changed

+16
-17
lines changed

README.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ tf2onnx - convert TensorFlow models to ONNX models.
66
# Supported ONNX version
77
tensorflow-onnx will use the ONNX version installed on your system and installs the latest ONNX version if none is found.
88

9-
By default we use opset 7 for the resulting ONNX graph since most runtimes will support opset 7. Opset 7 was introduced in onnx-1.2.
9+
We support opset 6 to 10. By default we use opset 7 for the resulting ONNX graph since most runtimes will support opset 7.
1010

11-
Newer releases of ONNX support higher opsets. For example, to create an ONNX graph for opset 8 use in the command line ```--opset 8```.
11+
If you want the graph to be generated with a newer opset, use ```--opset``` in the command line, for example ```--opset 10```.
1212

1313
# Status
1414
We support many TensorFlow models. Support for Fully Connected and Convolutional networks is mature. Dynamic LSTM/GRU/Attention networks should work but the code for this is evolving.
@@ -41,7 +41,7 @@ For pytorch/caffe2, follow the instructions here:
4141
We tested with pytorch/caffe2 and onnxruntime and unit tests are passing for those.
4242

4343
## Supported Tensorflow and Python Versions
44-
We tested with tensorflow 1.5-1.13 and anaconda **3.5,3.6**.
44+
We are testing with tensorflow 1.5-1.13 and anaconda **3.5,3.6,3.7**.
4545

4646
# Installation
4747
## From pypi
@@ -55,7 +55,7 @@ python setup.py install
5555
or
5656
python setup.py develop
5757
```
58-
tensorflow-onnx requires onnx-1.2.2 or better and will install/upgrade onnx if needed.
58+
tensorflow-onnx requires onnx-1.5 or better and will install/upgrade onnx if needed.
5959

6060
To create a distribution:
6161
```
@@ -69,10 +69,10 @@ names with ```--inputs INPUTS``` and ```--outputs OUTPUTS```.
6969

7070
```
7171
python -m tf2onnx.convert
72-
--input SOURCE_GRAPHDEF_PB
73-
--graphdef SOURCE_GRAPHDEF_PB
74-
--checkpoint SOURCE_CHECKPOINT
75-
--saved-model SOURCE_SAVED_MODEL
72+
[--input SOURCE_GRAPHDEF_PB]
73+
[--graphdef SOURCE_GRAPHDEF_PB]
74+
[--checkpoint SOURCE_CHECKPOINT]
75+
[--saved-model SOURCE_SAVED_MODEL]
7676
[--output TARGET_ONNX_MODEL]
7777
[--inputs GRAPH_INPUTS]
7878
[--outputs GRAPH_OUTPUS]

tests/test_backend.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -900,10 +900,9 @@ def test_tanh(self):
900900
_ = tf.identity(x_, name=_TFOUTPUT)
901901
self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=1e-05)
902902

903-
@check_onnxruntime_incompatibility("Max")
904903
def test_relu6(self):
905-
x_val = np.array([0.5, 1.0, -0.5, -1.0], dtype=np.float32).reshape((2, 2))
906-
x = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
904+
x_val = np.array([0.5, 1.0, -0.5, -1.0, 6, 7], dtype=np.float32).reshape((2, 3))
905+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
907906
x_ = tf.nn.relu6(x)
908907
_ = tf.identity(x_, name=_TFOUTPUT)
909908
self._run_test_case([_OUTPUT], {_INPUT: x_val})

tests/test_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,8 @@ def test_relu6(self):
249249
_ = tf.identity(x_, name="output")
250250
g = process_tf_graph(sess.graph, opset=self.config.opset)
251251
self.assertEqual(
252-
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Relu6 [op_type=Relu] Relu6__2 [op_type=Clip] '
253-
'output [op_type=Identity] input1:0 -> Relu6 Relu6:0 -> Relu6__2 Relu6__2:0 -> output }',
252+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Relu6 [op_type=Clip] output [op_type=Identity] '
253+
'input1:0 -> Relu6 Relu6:0 -> output }',
254254
onnx_to_graphviz(g))
255255

256256
def test_conv2d(self):

tf2onnx/onnx_opset/math.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,10 @@ class Relu6:
135135
@classmethod
136136
def version_4(cls, ctx, node, **kwargs):
137137
# relu6 = min(max(features, 0), 6)
138-
node.type = "Relu"
139-
clip_name = utils.make_name(node.name)
140-
clip_node = ctx.insert_new_node_on_output("Clip", node.output[0], name=clip_name, min=0.0, max=6.0)
141-
ctx.copy_shape(node.output[0], clip_node.output[0])
138+
# relu6 = min(max(features, 0), 6)
139+
node.type = "Clip"
140+
node.set_attr("min", 0.0)
141+
node.set_attr("max", 6.0)
142142

143143

144144
@tf_op("Rsqrt")

0 commit comments

Comments
 (0)