Skip to content

Commit e4a97d1

Browse files
authored
Merge branch 'master' into tflfft
2 parents 05b7f11 + 9e48a44 commit e4a97d1

File tree

9 files changed

+100
-6
lines changed

9 files changed

+100
-6
lines changed

ci_build/azure_pipelines/templates/setup.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ steps:
4141
then
4242
pip install tensorflow-text>=2.5
4343
fi
44+
if [[ $CI_TF_VERSION == 2.6* ]] ;
45+
then
46+
# FIXME: make it >= 2.6 after offical tensorflow-text was released
47+
pip install tensorflow-text==2.6.0rc0
48+
fi
4449
fi
4550
4651
python setup.py install

ci_build/azure_pipelines/unit_test.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@
33
stages:
44
- stage:
55
jobs:
6+
- template: 'templates/job_generator.yml'
7+
parameters:
8+
# tf 2.5
9+
python_versions: ['3.8']
10+
tf_versions: ['2.6.0rc2']
11+
onnx_opsets: ['14']
12+
job:
13+
steps:
14+
- template: 'unit_test.yml'
15+
report_coverage: 'True'
16+
617
- template: 'templates/job_generator.yml'
718
parameters:
819
# TFJS tf 2.5

tests/keras2onnx_applications/nightly_build/test_resnext.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
import numpy as np
88
from mock_keras2onnx.proto import keras
99
from mock_keras2onnx.proto.tfcompat import is_tf2
10-
from keras.regularizers import l2
10+
if is_tf2:
11+
def l2(weight_decay):
12+
# old keras layer expects a tuple but tf keras wants a single value
13+
return keras.regularizers.l2(weight_decay[0])
14+
else:
15+
from keras.regularizers import l2
1116
from os.path import dirname, abspath
1217
sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_tests/'))
1318
img_path = os.path.join(os.path.dirname(__file__), '../data', 'street.jpg')

tests/test_backend.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2093,6 +2093,21 @@ def func(shape):
20932093
self.assertTrue(-0.1 < np.mean(results) < 0.1)
20942094
self.assertTrue(0.9 < np.std(results) < 1.1)
20952095

2096+
@check_opset_min_version(10, "TopK")
2097+
def test_random_shuffle(self):
2098+
x_val = make_xval([5, 4, 3])
2099+
def func(x):
2100+
x_ = tf.random.shuffle(x)
2101+
return tf.identity(x_, name=_TFOUTPUT)
2102+
# since results are random, compare the shapes only
2103+
g = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, check_value=False, check_shape=True)
2104+
feed_dict = {_INPUT: x_val}
2105+
if "input" in g.input_names:
2106+
# TFLite inputs don't have port numbers
2107+
feed_dict = {k.split(":")[0]: v for k, v in feed_dict.items()}
2108+
results = self.run_backend(g, g.outputs, feed_dict)
2109+
np.testing.assert_allclose(x_val, np.sort(results[0], axis=0))
2110+
20962111
def test_randomuniform_int(self):
20972112
def func():
20982113
shape = tf.constant([100, 3], name="shape")
@@ -3280,6 +3295,16 @@ def func(x):
32803295
return tf.identity(picks, name=_TFOUTPUT)
32813296
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
32823297

3298+
@check_opset_min_version(9, "IsNaN")
3299+
def test_where_isnan(self):
3300+
x_val = np.array([1, 2, -3, float('nan'), -5, -6, float('nan'), 8, 9, 0], dtype=np.float32)
3301+
true_result = np.array([111, 222, 333, 444, 555, 666, 777, 888, 999, 1000],
3302+
dtype=np.float32)
3303+
def func(x):
3304+
picks = tf.where(is_nan(x), true_result, x)
3305+
return tf.identity(picks, name=_TFOUTPUT)
3306+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3307+
32833308
@check_opset_min_version(9, "Where for strings needs opset 9")
32843309
@skip_tfjs("Technically tf where doesn't support strings and tfjs doesn't like it")
32853310
def test_where_string(self):

tf2onnx/onnx_opset/controlflow.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,15 @@ def version_9(cls, ctx, node, **kwargs):
184184
# We can't use the mul/add trick if a NaN is involved. handles_nan is added earlier in the converter.
185185
handles_nan = node.get_attr_value("handles_nan", False)
186186
if ctx.get_dtype(node.output[0]) in [TensorProto.FLOAT, TensorProto.DOUBLE]:
187+
cond_node = node.inputs[0]
188+
if cond_node.type == "IsNaN":
189+
handles_nan = True
190+
if cond_node.type == "NotEqual" and cond_node.input[0] == cond_node.input[1]:
191+
handles_nan = True
192+
if cond_node.type == "Not" and cond_node.inputs[0].type == "Equal":
193+
eq_node = cond_node.inputs[0]
194+
if eq_node.input[0] == eq_node.input[1]:
195+
handles_nan = True
187196
for inp in node.inputs[1:]:
188197
if inp.is_const() and np.any(np.isnan(inp.get_tensor_value(as_list=False))):
189198
handles_nan = True

tf2onnx/onnx_opset/generator.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,28 @@ class PassThroughOp:
109109
def version_1(cls, ctx, node, **kwargs):
110110
pass
111111

112+
@tf_op(["RandomShuffle"])
113+
class RandomShuffleOp:
114+
@classmethod
115+
def version_10(cls, ctx, node, **kwargs):
116+
inp_shape = ctx.make_node("Shape", [node.input[0]]).output[0]
117+
dim_0 = GraphBuilder(ctx).make_slice({'data': inp_shape, 'starts': [0], 'ends': [1], 'axes': [0]})
118+
zeros = ctx.make_node("ConstantOfShape", [dim_0], shapes=[[-1]]).output[0]
119+
120+
seed = node.get_attr_value("seed", 0)
121+
seed2 = node.get_attr_value("seed2", 0)
122+
onnx_seed = utils.combine_seeds(seed, seed2)
123+
rand_attr = {'dtype': onnx_pb.TensorProto.FLOAT}
124+
if onnx_seed is not None:
125+
rand_attr['seed'] = onnx_seed
126+
127+
random_floats = ctx.make_node("RandomUniformLike", [zeros], op_name_scope=node.name, shapes=[[-1]],
128+
attr=rand_attr).output[0]
129+
# Use indices of the TopK to get a random ordering
130+
_, random_ordering = ctx.make_node("TopK", [random_floats, dim_0], output_count=2, attr={'axis': -1}).output
131+
shuffled_res = ctx.make_node("Gather", [node.input[0], random_ordering]).output[0]
132+
ctx.replace_all_inputs(node.output[0], shuffled_res)
133+
112134
@tf_op("Fill")
113135
class Fill:
114136
@classmethod

tf2onnx/onnx_opset/nn.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -990,11 +990,10 @@ def version_9(cls, ctx, node, **kwargs):
990990

991991
seed = node.get_attr_value("seed", 0)
992992
seed2 = node.get_attr_value("seed2", 0)
993+
onnx_seed = utils.combine_seeds(seed, seed2)
993994
rand_attr = {}
994-
if seed != 0 or seed2 != 0:
995-
# Produce a unique value depending on both seeds. (diagonal grid traversal)
996-
combined_seed = (seed + seed2 + 1) * (seed + seed2 + 2) // 2 - seed
997-
rand_attr['seed'] = float(combined_seed)
995+
if onnx_seed is not None:
996+
rand_attr['seed'] = onnx_seed
998997

999998
min_aspect_ratio, max_aspect_ratio = node.get_attr_value("aspect_ratio_range", [0.75, 1.33])
1000999
ratio_range = max_aspect_ratio - min_aspect_ratio

tf2onnx/tf_loader.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,18 @@ def make_tensor_proto_wrapped(values, dtype=None, shape=None, verify_shape=False
161161
tensor_util.make_tensor_proto = make_tensor_proto_wrapped
162162

163163
try:
164-
converter_data = _FunctionConverterData(func=func, lower_control_flow=False, aggressive_inlining=True)
164+
function_converter = _FunctionConverterData
165+
if LooseVersion(tf.__version__) >= "2.6.0":
166+
from tensorflow.python.eager import context
167+
from tensorflow.python.framework.convert_to_constants import _FunctionConverterDataInEager, \
168+
_FunctionConverterDataInGraph
169+
if context.executing_eagerly():
170+
function_converter = _FunctionConverterDataInEager
171+
else:
172+
function_converter = _FunctionConverterDataInGraph
173+
else:
174+
function_converter = _FunctionConverterData
175+
converter_data = function_converter(func=func, lower_control_flow=False, aggressive_inlining=True)
165176
frozen_graph_def, _ = _replace_variables_by_constants(converter_data=converter_data)
166177
finally:
167178
tensor_util.make_tensor_proto = make_tensor_proto_original

tf2onnx/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,13 @@ def make_sure(bool_val, error_msg, *args):
259259
if not bool_val:
260260
raise ValueError("make_sure failure: " + error_msg % args)
261261

262+
def combine_seeds(seed, seed2):
263+
"""Produces an onnx float seed from two tf int seeds. Returns None if both seeds are 0."""
264+
if seed != 0 or seed2 != 0:
265+
# Produce a unique value depending on both seeds. (diagonal grid traversal)
266+
combined_seed = (seed + seed2 + 1) * (seed + seed2 + 2) // 2 - seed
267+
return float(combined_seed)
268+
return None
262269

263270
def topological_sort(dependencies):
264271
"""

0 commit comments

Comments
 (0)