Skip to content

Commit dff63d1

Browse files
committed
Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into bench2
2 parents 0b9da56 + d18d3f7 commit dff63d1

File tree

11 files changed

+299
-21
lines changed

11 files changed

+299
-21
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ The common issues we run into we try to document here [Troubleshooting Guide](Tr
1818

1919
| Build Type | OS | Python | Tensorflow | ONNX opset | Status |
2020
| --- | --- | --- | --- | --- | --- |
21-
| Unit Test - Basic | Linux, MacOS<sup>\*</sup>, Windows<sup>\*</sup> | 3.6-3.9 | 1.12-1.15, 2.1-2.5 | 8-15 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=master) |
21+
| Unit Test - Basic | Linux, MacOS<sup>\*</sup>, Windows<sup>\*</sup> | 3.6-3.9 | 1.12-1.15, 2.1-2.5 | 8-14 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=master) |
2222
| Unit Test - Full | Linux, MacOS, Windows | 3.6-3.9 | 1.12-1.15, 2.1-2.5 | 8-14 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=master) | |
2323
<br/>
2424

ci_build/azure_pipelines/keras2onnx_unit_test.yml

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ jobs:
119119
vmImage: 'vs2017-win2016'
120120
strategy:
121121
matrix:
122-
# No python 2.x since no available ONNX package for Windows
122+
############ TF Keras Unit Tests ############
123123
Python36-onnx1.2:
124124
python.version: '3.6'
125125
ONNX_PATH: onnx==1.2.3
@@ -156,7 +156,27 @@ jobs:
156156
TENSORFLOW_PATH: tensorflow-cpu==2.5.0
157157
INSTALL_ORT: pip install onnxruntime==1.8.0
158158

159-
maxParallel: 3
159+
############ Pure Keras Unit Tests ############
160+
Keras-Py36-tf1.15.0:
161+
python.version: '3.6'
162+
ONNX_PATH: onnx==1.5.0
163+
KERAS: keras==2.2.5
164+
TENSORFLOW_PATH: tensorflow==1.15.0
165+
INSTALL_ORT: pip install onnxruntime==1.8.0
166+
167+
Keras-Py37-tf2.0.0:
168+
python.version: '3.7'
169+
ONNX_PATH: onnx==1.7.0
170+
KERAS: keras==2.3.1
171+
TENSORFLOW_PATH: tensorflow==2.0.0
172+
INSTALL_ORT: pip install onnxruntime==1.8.0
173+
174+
Keras-Py37-tf2.2.0:
175+
python.version: '3.7'
176+
ONNX_PATH: onnx==1.9.0
177+
KERAS: keras==2.4.3
178+
TENSORFLOW_PATH: tensorflow==2.2.0
179+
INSTALL_ORT: pip install onnxruntime==1.8.0
160180

161181
steps:
162182
- task: UsePythonVersion@0
@@ -179,6 +199,7 @@ jobs:
179199
pip install protobuf
180200
pip install h5py==2.9.0
181201
pip install %TENSORFLOW_PATH%
202+
IF NOT "%KERAS%"=="" (pip install %KERAS%)
182203
pip install git+https://github.com/microsoft/onnxconverter-common
183204
pip install pytest pytest-cov pytest-runner
184205
%INSTALL_ORT%
@@ -189,6 +210,7 @@ jobs:
189210
pip install -e .
190211
echo Test onnxruntime installation... && python -c "import onnxruntime"
191212
python -c "import onnxconverter_common"
213+
IF NOT "%KERAS%"=="" (set TF_KERAS=0)
192214
pytest keras2onnx_tests --doctest-modules --junitxml=junit/test-results.xml
193215
displayName: 'pytest'
194216

tests/test_backend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3791,6 +3791,18 @@ def func(input_x, boxes, box_ind, corp_size):
37913791
{_INPUT: input_x_val, _INPUT1: boxes_val, _INPUT2: box_ind_val, _INPUT3: corp_size_val},
37923792
rtol=1e-04, atol=1e-03)
37933793

3794+
@check_opset_min_version(11, "CropAndResize")
3795+
def test_crop_and_resize_empty_tensor(self):
3796+
def func(input_x, boxes, box_ind, corp_size):
3797+
return tf.image.crop_and_resize(input_x, boxes, box_ind, corp_size, name=_TFOUTPUT, extrapolation_value=1.0)
3798+
input_x_val = np.random.randint(low=0, high=256, size=[0, 36, 36, 3]).astype(np.float32) # NHWC
3799+
boxes_val = np.array([]).astype(np.float32).reshape([0, 4])
3800+
box_ind_val = np.array([]).astype(np.int32)
3801+
corp_size_val = np.array([40, 40]).astype(np.int32)
3802+
self._run_test_case(func, [_OUTPUT],
3803+
{_INPUT: input_x_val, _INPUT1: boxes_val, _INPUT2: box_ind_val, _INPUT3: corp_size_val},
3804+
rtol=1e-04, atol=1e-03)
3805+
37943806
def test_batch_to_space3d(self):
37953807
block_size = [2, 2]
37963808
crop = [[0, 1], [2, 1]]
@@ -4337,6 +4349,14 @@ def func(x):
43374349
return tf.identity(x_, name=_TFOUTPUT)
43384350
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
43394351

4352+
def test_round_approx(self):
4353+
# In lower opsets there is no Round, but we can approximate it forgoing nearest even
4354+
x_val = np.array([-0.7, -0.5, -0.0, 0.0, +0.0, 0.3, 1.5, 0.7, float('nan')], dtype=np.float32)
4355+
def func(x):
4356+
x_ = tf.round(x)
4357+
return tf.identity(x_, name=_TFOUTPUT)
4358+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
4359+
43404360
@check_opset_min_version(11, "Det")
43414361
@unittest.skip("unclear how this is called in tf-2, fix later")
43424362
def test_determinant(self):

tests/test_lstm.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tensorflow.python.ops import init_ops
1010
from tensorflow.python.ops import variable_scope
1111
from backend_test_base import Tf2OnnxBackendTestBase
12-
from common import unittest_main, check_opset_after_tf_version, skip_tf2, skip_tf_versions
12+
from common import unittest_main, check_opset_after_tf_version, skip_tf2, skip_tf_versions, check_op_count
1313

1414
from tf2onnx.tf_loader import is_tf2
1515

@@ -36,12 +36,22 @@
3636

3737
class LSTMTests(Tf2OnnxBackendTestBase):
3838

39-
def run_test_case(self, *args, **kwargs): #pylint: disable=arguments-differ
39+
def run_test_case(self, *args, require_lstm_count=1, **kwargs): #pylint: disable=arguments-differ
4040
# TF LSTM has an unknown dim
4141
tmp = self.config.allow_missing_shapes
4242
self.config.allow_missing_shapes = True
43+
def graph_validator(g):
44+
good = True
45+
if "graph_validator" in kwargs:
46+
good = good and kwargs["graph_validator"](g)
47+
if require_lstm_count is None or ":" not in g.outputs[0]:
48+
# Skip checks for tflite graphs (no ":" in outputs)
49+
return good
50+
good = good and check_op_count(g, "LSTM", require_lstm_count, disabled=False)
51+
good = good and check_op_count(g, "Loop", 0, disabled=False)
52+
return good
4353
try:
44-
super().run_test_case(*args, **kwargs)
54+
super().run_test_case(*args, graph_validator=graph_validator, **kwargs)
4555
finally:
4656
self.config.allow_missing_shapes = tmp
4757

@@ -385,7 +395,8 @@ def func(x):
385395
feed_dict = {"input_1:0": x_val}
386396
input_names_with_port = ["input_1:0"]
387397
output_names_with_port = ["output:0", "cell_state:0"]
388-
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
398+
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
399+
require_lstm_count=2)
389400

390401
@check_opset_after_tf_version("1.15", 8, "might need Scan")
391402
@skip_tf2() # Still failing likely due to inconsistent random number initialization

tf2onnx/graph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,8 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
464464
# A list of index, output tuples of potential scan outputs in this graph
465465
# Used by the tflite while loop handler
466466
self.scan_outputs = []
467+
# Used by lstm_tf2_rewriter to indicate this subgraph is an LSTM cell
468+
self.lstm_rewriter_context = None
467469
self.func_inputs = []
468470
self.ragged_variant_list_reads = []
469471
self.ragged_variant_list_writes = []

tf2onnx/onnx_opset/math.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,15 @@ def version_11(cls, ctx, node, **kwargs):
535535

536536
@tf_op("Round")
537537
class Round:
538+
@classmethod
539+
def version_1(cls, ctx, node, **kwargs):
540+
# Not exactly nearest even but close enough
541+
np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[0]))
542+
const_half = ctx.make_const(utils.make_name("const_half"), np.array(0.5, np_dtype)).output[0]
543+
add_node = ctx.make_node("Add", [node.input[0], const_half], op_name_scope=node.name).output[0]
544+
node.type = "Floor"
545+
ctx.replace_inputs(node, [add_node])
546+
538547
@classmethod
539548
def version_11(cls, ctx, node, **kwargs):
540549
pass

tf2onnx/onnx_opset/nn.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,10 @@ def build_dynamic_target_size(ctx, transposed_intput, target_hw):
318318
shape_of_transposed_input = ctx.make_node("Shape", [transposed_intput])
319319
first_half_of_shape = GraphBuilder(ctx).make_slice(
320320
{"data": shape_of_transposed_input.output[0], "ends": [2], "starts": [0]})
321-
target_size_int64 = ctx.make_node("Cast", [target_hw], attr={'to': TensorProto.INT64})
321+
if ctx.get_dtype(target_hw) != TensorProto.INT64:
322+
target_hw = ctx.make_node("Cast", [target_hw], attr={'to': TensorProto.INT64}).output[0]
322323
# We build a tensor containing [n c nh nw]
323-
final_target_size = ctx.make_node("Concat", [first_half_of_shape, target_size_int64.output[0]], {'axis': 0})
324+
final_target_size = ctx.make_node("Concat", [first_half_of_shape, target_hw], {'axis': 0})
324325
return final_target_size
325326

326327

@@ -1183,9 +1184,13 @@ def any_version_after11(cls, opset, ctx, node, **kwargs):
11831184
"method").s == b"nearest" else "linear"
11841185
extrapolation_value = float(node.get_attr("extrapolation_value", "0").f)
11851186
input_x = node.input[0]
1187+
x_shape = ctx.make_node("Shape", [input_x]).output[0]
1188+
num_channels = GraphBuilder(ctx).make_slice({"data": x_shape, "starts": [3], "ends": [4], "axes": [0]})
11861189
boxes = node.input[1]
11871190
box_ind = node.input[2]
11881191
crop_size = node.input[3]
1192+
if ctx.get_dtype(crop_size) != TensorProto.INT64:
1193+
crop_size = ctx.make_node("Cast", [crop_size], attr={'to': TensorProto.INT64}).output[0]
11891194
trip_name = utils.make_name(node.name + "_i")
11901195
cond_name = utils.make_name(node.name + "_cond")
11911196
cond_out_name = utils.make_name(node.name + "cond_out")
@@ -1233,6 +1238,10 @@ def any_version_after11(cls, opset, ctx, node, **kwargs):
12331238
branches = {"body": g}
12341239
inner_loop = ctx.make_node("Loop", [trip_node.output[0], cond_const.output[0]], name=node.name,
12351240
outputs=node.output, branches=branches)
1241+
const_neg_one = ctx.make_const(utils.make_name("const_neg_one"), np.array([-1], np.int64)).output[0]
1242+
final_shape = ctx.make_node("Concat", [const_neg_one, crop_size, num_channels], attr={'axis': 0}).output[0]
1243+
# This reshape fixes the case when there are no iterations and the scan output is empty.
1244+
ctx.insert_new_node_on_output("Reshape", inner_loop.output[0], inputs=[inner_loop.output[0], final_shape])
12361245

12371246
@classmethod
12381247
def version_11(cls, ctx, node, **kwargs):

tf2onnx/rewriter/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from tf2onnx.rewriter.quantization_ops_rewriter import rewrite_quantize_and_dequantize
2222
from tf2onnx.rewriter.layer_normalization_rewriter import rewrite_layer_normalization
2323
from tf2onnx.rewriter.ragged_variant_shape_rewriter import rewrite_ragged_variant_shape
24+
from tf2onnx.rewriter.lstm_tf2_rewriter import rewriter_lstm_tf2
2425

2526

2627
__all__ = [
@@ -46,5 +47,6 @@
4647
"rewrite_quantize_and_dequantize",
4748
"rewrite_layer_normalization",
4849
"rewrite_conv_dilations",
49-
"rewrite_ragged_variant_shape"
50+
"rewrite_ragged_variant_shape",
51+
"rewriter_lstm_tf2"
5052
]

0 commit comments

Comments
 (0)