Skip to content

Commit e00594d

Browse files
Implement RandomShuffle op (#1658)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent a9f636f commit e00594d

File tree

4 files changed

+47
-4
lines changed

4 files changed

+47
-4
lines changed

tests/test_backend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2093,6 +2093,21 @@ def func(shape):
20932093
self.assertTrue(-0.1 < np.mean(results) < 0.1)
20942094
self.assertTrue(0.9 < np.std(results) < 1.1)
20952095

2096+
@check_opset_min_version(10, "TopK")
2097+
def test_random_shuffle(self):
2098+
x_val = make_xval([5, 4, 3])
2099+
def func(x):
2100+
x_ = tf.random.shuffle(x)
2101+
return tf.identity(x_, name=_TFOUTPUT)
2102+
# since results are random, compare the shapes only
2103+
g = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, check_value=False, check_shape=True)
2104+
feed_dict = {_INPUT: x_val}
2105+
if "input" in g.input_names:
2106+
# TFLite inputs don't have port numbers
2107+
feed_dict = {k.split(":")[0]: v for k, v in feed_dict.items()}
2108+
results = self.run_backend(g, g.outputs, feed_dict)
2109+
np.testing.assert_allclose(x_val, np.sort(results[0], axis=0))
2110+
20962111
def test_randomuniform_int(self):
20972112
def func():
20982113
shape = tf.constant([100, 3], name="shape")

tf2onnx/onnx_opset/generator.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,28 @@ class PassThroughOp:
109109
def version_1(cls, ctx, node, **kwargs):
110110
pass
111111

112+
@tf_op(["RandomShuffle"])
113+
class RandomShuffleOp:
114+
@classmethod
115+
def version_10(cls, ctx, node, **kwargs):
116+
inp_shape = ctx.make_node("Shape", [node.input[0]]).output[0]
117+
dim_0 = GraphBuilder(ctx).make_slice({'data': inp_shape, 'starts': [0], 'ends': [1], 'axes': [0]})
118+
zeros = ctx.make_node("ConstantOfShape", [dim_0], shapes=[[-1]]).output[0]
119+
120+
seed = node.get_attr_value("seed", 0)
121+
seed2 = node.get_attr_value("seed2", 0)
122+
onnx_seed = utils.combine_seeds(seed, seed2)
123+
rand_attr = {'dtype': onnx_pb.TensorProto.FLOAT}
124+
if onnx_seed is not None:
125+
rand_attr['seed'] = onnx_seed
126+
127+
random_floats = ctx.make_node("RandomUniformLike", [zeros], op_name_scope=node.name, shapes=[[-1]],
128+
attr=rand_attr).output[0]
129+
# Use indices of the TopK to get a random ordering
130+
_, random_ordering = ctx.make_node("TopK", [random_floats, dim_0], output_count=2, attr={'axis': -1}).output
131+
shuffled_res = ctx.make_node("Gather", [node.input[0], random_ordering]).output[0]
132+
ctx.replace_all_inputs(node.output[0], shuffled_res)
133+
112134
@tf_op("Fill")
113135
class Fill:
114136
@classmethod

tf2onnx/onnx_opset/nn.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -990,11 +990,10 @@ def version_9(cls, ctx, node, **kwargs):
990990

991991
seed = node.get_attr_value("seed", 0)
992992
seed2 = node.get_attr_value("seed2", 0)
993+
onnx_seed = utils.combine_seeds(seed, seed2)
993994
rand_attr = {}
994-
if seed != 0 or seed2 != 0:
995-
# Produce a unique value depending on both seeds. (diagonal grid traversal)
996-
combined_seed = (seed + seed2 + 1) * (seed + seed2 + 2) // 2 - seed
997-
rand_attr['seed'] = float(combined_seed)
995+
if onnx_seed is not None:
996+
rand_attr['seed'] = onnx_seed
998997

999998
min_aspect_ratio, max_aspect_ratio = node.get_attr_value("aspect_ratio_range", [0.75, 1.33])
1000999
ratio_range = max_aspect_ratio - min_aspect_ratio

tf2onnx/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,13 @@ def make_sure(bool_val, error_msg, *args):
259259
if not bool_val:
260260
raise ValueError("make_sure failure: " + error_msg % args)
261261

262+
def combine_seeds(seed, seed2):
263+
"""Produces an onnx float seed from two tf int seeds. Returns None if both seeds are 0."""
264+
if seed != 0 or seed2 != 0:
265+
# Produce a unique value depending on both seeds. (diagonal grid traversal)
266+
combined_seed = (seed + seed2 + 1) * (seed + seed2 + 2) // 2 - seed
267+
return float(combined_seed)
268+
return None
262269

263270
def topological_sort(dependencies):
264271
"""

0 commit comments

Comments
 (0)