Skip to content

Commit fac2c91

Browse files
Tom/tf 2.5 (#1462)
* add tf-2.5 to ci Signed-off-by: Guenther Schmuelling <[email protected]> * theta must be >- 0 Signed-off-by: Guenther Schmuelling <[email protected]> * Update tflite flatbuffer Signed-off-by: Tom Wildenhain <[email protected]> * Fixes for tf2.5 Signed-off-by: Tom Wildenhain <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent 4cfc025 commit fac2c91

26 files changed

+482
-17
lines changed

ci_build/azure_pipelines/onnxruntime_nightly_test.yml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@ stages:
3939
- template: 'unit_test.yml'
4040
report_coverage: 'True'
4141

42+
- template: 'templates/job_generator.yml'
43+
parameters:
44+
platforms: ['linux', 'windows']
45+
python_versions: [3.8']
46+
tf_versions: ['2.5.0rc1']
47+
onnx_opsets: ['']
48+
onnx_backends: {onnxruntime: ['nightly']}
49+
job:
50+
steps:
51+
- template: 'unit_test.yml'
52+
report_coverage: 'True'
53+
4254
- template: 'templates/combine_test_coverage.yml'
4355

4456
schedules:

ci_build/azure_pipelines/pretrained_model_test-matrix.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,12 @@ jobs:
2727
job:
2828
steps:
2929
- template: 'pretrained_model_test.yml'
30-
30+
31+
- template: 'templates/job_generator.yml'
32+
parameters:
33+
platforms: ['linux', 'windows']
34+
python_versions: ['3.8']
35+
tf_versions: ['2.5.0rc1']
36+
job:
37+
steps:
38+
- template: 'pretrained_model_test.yml'

ci_build/azure_pipelines/unit_test-matrix.yml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ stages:
77
parameters:
88
platforms: ['linux', 'windows']
99
python_versions: ['3.6']
10-
tf_versions: ['1.13.1', '1.12.3']
10+
tf_versions: ['1.12.3']
1111
onnx_opsets: ['']
1212
job:
1313
steps:
@@ -35,5 +35,16 @@ stages:
3535
steps:
3636
- template: 'unit_test.yml'
3737
report_coverage: 'True'
38+
39+
- template: 'templates/job_generator.yml'
40+
parameters:
41+
platforms: ['linux', 'windows']
42+
python_versions: ['3.8']
43+
tf_versions: ['2.5.0rc1']
44+
onnx_opsets: ['']
45+
job:
46+
steps:
47+
- template: 'unit_test.yml'
48+
report_coverage: 'True'
3849

3950
- template: 'templates/combine_test_coverage.yml'

ci_build/azure_pipelines/unit_test.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ stages:
66
- template: 'templates/job_generator.yml'
77
parameters:
88
python_versions: ['3.8']
9-
tf_versions: ['2.4.0']
9+
tf_versions: ['2.5.0rc1']
1010
onnx_opsets: ['']
1111
skip_tflite_tests: 'False'
1212
skip_tf_tests: 'True'
@@ -30,7 +30,7 @@ stages:
3030
- template: 'templates/job_generator.yml'
3131
parameters:
3232
python_versions: ['3.8']
33-
tf_versions: ['2.4.0']
33+
tf_versions: ['2.5.0rc1']
3434
onnx_opsets: ['']
3535
job:
3636
steps:
@@ -40,7 +40,7 @@ stages:
4040
- template: 'templates/job_generator.yml'
4141
parameters:
4242
python_versions: ['3.7']
43-
tf_versions: ['1.14.0','1.15.2','2.2.0','2.3.0']
43+
tf_versions: ['1.14.0','1.15.2','2.3.0','2.4.1']
4444
onnx_opsets: ['']
4545
job:
4646
steps:
@@ -71,7 +71,7 @@ stages:
7171
parameters:
7272
python_versions: ['3.7']
7373
platforms: ['windows']
74-
tf_versions: ['2.3.0']
74+
tf_versions: ['2.4.1']
7575
onnx_opsets: ['']
7676
job:
7777
steps:

tests/backend_test_base.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from tf2onnx.tf_loader import tf_optimize, is_tf2, get_hash_table_info
3131
from tf2onnx.tf_utils import compress_graph_def
3232
from tf2onnx.graph import ExternalTensorStorage
33+
from tf2onnx.tflite.Model import Model
3334

3435

3536
if is_tf2():
@@ -217,6 +218,22 @@ def convert_to_tflite(self, graph_def, feed_dict, outputs):
217218
except ConverterError:
218219
return None
219220

221+
def tflite_has_supported_types(self, tflite_path):
222+
try:
223+
with open(tflite_path, 'rb') as f:
224+
buf = f.read()
225+
buf = bytearray(buf)
226+
model = Model.GetRootAsModel(buf, 0)
227+
tensor_cnt = model.Subgraphs(0).TensorsLength()
228+
interpreter = tf.lite.Interpreter(tflite_path)
229+
for i in range(tensor_cnt):
230+
dtype = interpreter._get_tensor_details(i)['dtype'] # pylint: disable=protected-access
231+
if np.dtype(dtype).kind == 'O':
232+
return False
233+
return True
234+
except (RuntimeError, ValueError):
235+
return False
236+
220237
def run_tflite(self, tflite_path, feed_dict):
221238
try:
222239
interpreter = tf.lite.Interpreter(tflite_path)
@@ -293,7 +310,7 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
293310

294311
if test_tflite:
295312
tflite_path = self.convert_to_tflite(graph_def, feed_dict, output_names_with_port)
296-
test_tflite = tflite_path is not None
313+
test_tflite = tflite_path is not None and self.tflite_has_supported_types(tflite_path)
297314

298315
if test_tf:
299316
tf_reset_default_graph()

tests/test_backend.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3328,6 +3328,16 @@ def func(x, y):
33283328

33293329
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x, _INPUT1: input_y})
33303330
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y})
3331+
3332+
@check_opset_min_version(8, "BroadcastTo")
3333+
def test_zeros_like_bool(self):
3334+
input_x = np.random.random_sample([10, 20]).astype(np.float32)
3335+
input_y = np.array([20, 10]).astype(np.int64)
3336+
3337+
def func(x, y):
3338+
z = tf.reshape(x, y)
3339+
return tf.zeros_like(z, name=_TFOUTPUT)
3340+
33313341
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x > 0.5, _INPUT1: input_y})
33323342

33333343
@check_opset_min_version(9, "is_nan")
@@ -3746,7 +3756,7 @@ def test_conv1d_5(self):
37463756
def test_thresholded_relu(self):
37473757
# tf.keras.layers.ThresholdedReLU only supports `float32` for x
37483758
x_val = np.array([0.0, 1.0, -1.0, 2.0, -2.0, 0.5, -0.5, 1.5, -1.5], dtype=np.float32).reshape((3, 3))
3749-
theta_vals = [-1.0, -0.5, 0.0, 0.5, 1.0]
3759+
theta_vals = [0.0, 0.5, 1.0, 2.0]
37503760
for theta_val in theta_vals:
37513761
def func(x):
37523762
t = tf.keras.layers.ThresholdedReLU(theta=theta_val)

tf2onnx/onnx_opset/controlflow.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,15 @@ def version_9(cls, ctx, node, **kwargs):
183183
# T1 output = Where(bool condition, T1 x, T1 y)
184184
# NOTE: condition can be 1-dimension in tensorflow, while in onnx,
185185
# it should be broadcastable with other two inputs
186-
if ctx.get_dtype(node.output[0]) != TensorProto.STRING:
186+
187+
# We can't use the mul/add trick if a NaN is involved. handles_nan is added earlier in the converter.
188+
handles_nan = node.get_attr_value("handles_nan", False)
189+
if ctx.get_dtype(node.output[0]) in [TensorProto.FLOAT, TensorProto.DOUBLE]:
190+
for inp in node.inputs[1:]:
191+
if inp.is_const() and np.any(np.isnan(inp.get_tensor_value(as_list=False))):
192+
handles_nan = True
193+
194+
if ctx.get_dtype(node.output[0]) != TensorProto.STRING and not handles_nan:
187195
# Due to bad ORT implementation, Mul/Add ops are faster than Where op
188196
cls.version_7(ctx, node, **kwargs)
189197
return

tf2onnx/onnx_opset/logical.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ def version_1(cls, ctx, node, **kwargs):
7474
def version_7(cls, ctx, node, **kwargs):
7575
# T2 output = Equal(T1, x, T1 y), T1 \in {bool, int32, int64}
7676
need_not = node.type == "NotEqual"
77+
if need_not and node.input[0] == node.input[1]:
78+
# The only value not equal to itself is NaN
79+
node.type = "IsNaN"
80+
ctx.replace_inputs(node, [node.input[0]])
81+
return
7782
supported_dtypes = [
7883
TensorProto.BOOL,
7984
TensorProto.INT32,

tf2onnx/onnx_opset/tensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2065,11 +2065,13 @@ def version_10(cls, ctx, node, **kwargs):
20652065
batch_dim = node.get_attr_value("batch_dim", 0)
20662066

20672067
ctx.remove_node(node.name)
2068+
shape = ctx.get_shape(node.input[0])
2069+
dtype = ctx.get_dtype(node.input[0])
20682070
node = ctx.make_node(
20692071
"ReverseSequence",
20702072
node.input,
20712073
outputs=node.output,
2072-
attr={"batch_axis": batch_dim, "time_axis": seq_dim})
2074+
attr={"batch_axis": batch_dim, "time_axis": seq_dim}, shapes=[shape], dtypes=[dtype])
20732075

20742076
seq_len_dtype = ctx.get_dtype(node.input[1])
20752077
utils.make_sure(seq_len_dtype is not None, "dtype of {} is None".format(node.input[1]))

tf2onnx/tflite/BuiltinOperator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,13 @@ class BuiltinOperator(object):
137137
CALL_ONCE = 129
138138
BROADCAST_TO = 130
139139
RFFT2D = 131
140+
CONV_3D = 132
141+
IMAG = 133
142+
REAL = 134
143+
COMPLEX_ABS = 135
144+
HASHTABLE = 136
145+
HASHTABLE_FIND = 137
146+
HASHTABLE_IMPORT = 138
147+
HASHTABLE_SIZE = 139
148+
REDUCE_ALL = 140
140149

0 commit comments

Comments
 (0)