
Commit 5d2b73c

Convert RandomUniformInt (#1347)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent ad4c792

3 files changed: +87 -7 lines changed
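
For context: TensorFlow lowers integer-typed tf.random.uniform calls to the RandomUniformInt op, which tf2onnx previously could not convert. A minimal usage sketch of the pattern this commit makes convertible (shape and bounds here mirror the new tests, not the diff itself):

    import tensorflow as tf

    # tf.random.uniform with an integer dtype lowers to a RandomUniformInt
    # node in the graph; this commit teaches tf2onnx to convert it.
    x = tf.random.uniform([100, 3], minval=2, maxval=10, dtype=tf.int32)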

tests/test_backend.py

Lines changed: 49 additions & 4 deletions
@@ -1800,16 +1800,61 @@ def func():
         # since results are random, compare the shapes only
         self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
 
-    @unittest.skip("TF RandomUniformInt is not supported")
     def test_randomuniform_int(self):
         def func():
-            shape = tf.constant([2, 3], name="shape")
-            x_ = random_uniform(shape, name="rand", dtype=tf.int32, maxval=10)
+            shape = tf.constant([100, 3], name="shape")
+            x_ = random_uniform(shape, name="rand", dtype=tf.int32, minval=2, maxval=10)
             x_ = tf.identity(x_, name="output1")
             x_ = tf.identity(x_, name="output2")
             return tf.identity(x_, name=_TFOUTPUT)
         # since results are random, compare the shapes only
-        self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
+        g = self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
+        results = self.run_backend(g, [_OUTPUT], {})
+        numbers = set(results[0].flatten())
+        self.assertEqual(sorted(numbers), list(range(2, 10)))
+
+    def test_randomuniform_int_nonconst_max(self):
+        m_val = np.array(8, dtype=np.int32)
+        def func(m):
+            shape = tf.constant([100, 3], name="shape")
+            x_ = random_uniform(shape, name="rand", dtype=tf.int32, minval=0, maxval=m)
+            x_ = tf.identity(x_, name="output1")
+            x_ = tf.identity(x_, name="output2")
+            return tf.identity(x_, name=_TFOUTPUT)
+        g = self._run_test_case(func, [_OUTPUT], {_INPUT: m_val}, check_value=False, check_shape=True)
+        results = self.run_backend(g, [_OUTPUT], {_INPUT: m_val})
+        numbers = set(results[0].flatten())
+        self.assertEqual(sorted(numbers), list(range(8)))
+
+    def test_randomuniform_int_nonconst_min_max(self):
+        n_val = np.array(2, dtype=np.int32)
+        m_val = np.array(10, dtype=np.int32)
+        def func(n, m):
+            shape = tf.constant([100, 3], name="shape")
+            x_ = random_uniform(shape, name="rand", dtype=tf.int32, minval=n, maxval=m)
+            x_ = tf.identity(x_, name="output1")
+            x_ = tf.identity(x_, name="output2")
+            return tf.identity(x_, name=_TFOUTPUT)
+        g = self._run_test_case(func, [_OUTPUT], {_INPUT: n_val, _INPUT1: m_val}, check_value=False, check_shape=True)
+        results = self.run_backend(g, [_OUTPUT], {_INPUT: n_val, _INPUT1: m_val})
+        numbers = set(results[0].flatten())
+        self.assertEqual(sorted(numbers), list(range(2, 10)))
+
+    @check_opset_min_version(9, "RandomUniformLike")
+    def test_randomuniform_int_nonconst_min_max_shape(self):
+        n_val = np.array(2, dtype=np.int32)
+        m_val = np.array(10, dtype=np.int32)
+        s_val = np.array([100, 3], dtype=np.int64)
+        def func(n, m, s):
+            x_ = random_uniform(s, name="rand", dtype=tf.int32, minval=n, maxval=m)
+            x_ = tf.identity(x_, name="output1")
+            x_ = tf.identity(x_, name="output2")
+            return tf.identity(x_, name=_TFOUTPUT)
+        g = self._run_test_case(func, [_OUTPUT], {_INPUT: n_val, _INPUT1: m_val, _INPUT2: s_val},
+                                check_value=False, check_shape=True)
+        results = self.run_backend(g, [_OUTPUT], {_INPUT: n_val, _INPUT1: m_val, _INPUT2: s_val})
+        numbers = set(results[0].flatten())
+        self.assertEqual(sorted(numbers), list(range(2, 10)))
 
     @skip_caffe2_backend()
     @check_opset_after_tf_version("2.2", 9, "RandomUniform")
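
The new tests go beyond shape checks: they execute the converted model and assert that every integer in [minval, maxval) actually appears. A hedged numpy sketch of the arithmetic the converter emits for non-constant bounds (variable names invented here; constant bounds are folded into the RandomUniform attributes instead) shows why 300 samples over a width-8 range make that assertion reliable:

    import numpy as np

    # Stand-in for the ONNX RandomUniform output: floats in [0, 1).
    rng = np.random.default_rng(0)
    u = rng.random((100, 3))

    # The converter scales to [minval, maxval), floors, and casts to int32.
    minval, maxval = 2, 10
    ints = np.floor(u * (maxval - minval) + minval).astype(np.int32)

    # With 300 draws over 8 buckets, missing any bucket has probability
    # about 8 * (7/8)**300, i.e. effectively zero.
    assert sorted(set(ints.flatten())) == list(range(2, 10))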

tf2onnx/onnx_opset/generator.py

Lines changed: 37 additions & 3 deletions
@@ -28,8 +28,33 @@ def version_1(cls, ctx, node, **kwargs):
         pass
 
 
-@tf_op(["RandomNormal", "RandomUniform"])
+@tf_op(["RandomNormal", "RandomUniform", "RandomUniformInt"])
 class RandomOp:
+    @classmethod
+    def randuniform_int(cls, ctx, node, min_inp, max_inp):
+        dtype = ctx.get_dtype(node.output[0])
+        min_node = ctx.get_node_by_output(min_inp)
+        max_node = ctx.get_node_by_output(max_inp)
+        if min_node.is_const() and max_node.is_const():
+            node.set_attr('low', float(min_node.get_tensor_value()))
+            node.set_attr('high', float(max_node.get_tensor_value()))
+            out = node.output[0]
+        elif min_node.is_const() and min_node.get_tensor_value() == 0:
+            max_float = ctx.make_node("Cast", [max_inp], attr={'to': onnx_pb.TensorProto.FLOAT}).output[0]
+            mul_node = ctx.insert_new_node_on_output("Mul", node.output[0], inputs=[node.output[0], max_float])
+            out = mul_node.output[0]
+        else:
+            min_float = ctx.make_node("Cast", [min_inp], attr={'to': onnx_pb.TensorProto.FLOAT}).output[0]
+            max_float = ctx.make_node("Cast", [max_inp], attr={'to': onnx_pb.TensorProto.FLOAT}).output[0]
+            diff = ctx.make_node("Sub", [max_float, min_float]).output[0]
+            diff_float = ctx.make_node("Cast", [diff], attr={'to': onnx_pb.TensorProto.FLOAT}).output[0]
+            mul_node = ctx.insert_new_node_on_output("Mul", node.output[0], inputs=[node.output[0], diff_float])
+            mul = mul_node.output[0]
+            add_node = ctx.insert_new_node_on_output("Add", mul, inputs=[mul, min_float])
+            out = add_node.output[0]
+        floor_node = ctx.insert_new_node_on_output("Floor", out)
+        ctx.insert_new_node_on_output("Cast", floor_node.output[0], to=dtype)
+
     @classmethod
     def version_1(cls, ctx, node, **kwargs):
         # in tf-2.0 grappler optimizes the graph pretty well and our matching logic
@@ -43,6 +68,10 @@ def version_1(cls, ctx, node, **kwargs):
             ctx.remove_input(node, node.input[0], 0)
             node.set_attr("shape", shape)
             ctx.set_shape(node.output[0], shape)
+        if node.type == "RandomUniformInt":
+            cls.randuniform_int(ctx, node, node.input[0], node.input[1])
+            node.type = "RandomUniform"
+            ctx.replace_inputs(node, [])
 
     @classmethod
     def version_9(cls, ctx, node, **kwargs):
@@ -51,10 +80,15 @@ def version_9(cls, ctx, node, **kwargs):
         else:
             seed = node.get_attr("seed")
             node.set_attr("seed", float(seed.f))
-            cast_node = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64})
+            cast_node = ctx.make_node("Cast", [node.input[0]], attr={'to': onnx_pb.TensorProto.INT64})
             const_node = ctx.make_node("ConstantOfShape", cast_node.output)
+            inputs = node.input.copy()
             ctx.replace_inputs(node, const_node.output.copy())
-            node.type = node.type + 'Like'
+            if node.type == "RandomUniformInt":
+                cls.randuniform_int(ctx, node, inputs[1], inputs[2])
+                node.type = "RandomUniformLike"
+            else:
+                node.type = node.type + 'Like'
 
 
 @tf_op(["RandomNormalLike", "RandomUniformLike"])

tf2onnx/tf_utils.py

Lines changed: 1 addition & 0 deletions
@@ -221,6 +221,7 @@ def is_huge_shape(x):
                 outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
                 progress = True
             can_fold = node.type not in ['Enter', 'Placeholder', 'PlaceholderWithDefault']
+            can_fold = can_fold and not node.type.startswith('Random')
             can_fold = can_fold and len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
             # We can only fold nodes with a single output
             can_fold = can_fold and len(output_names) == 1 and output_names[0] not in outputs_to_values
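
This one-line guard matters because the TF-level constant-folding pass would otherwise evaluate a Random* node once and bake that single sample into the ONNX graph as a constant. A minimal sketch of the predicate's effect (the helper name is hypothetical; in the source it is a chain of can_fold updates inside the folding loop):

    def can_const_fold(node_type):
        # Random ops are stateful: folding one would freeze a single sample
        # into the graph as a constant, changing the model's semantics.
        if node_type in ['Enter', 'Placeholder', 'PlaceholderWithDefault']:
            return False
        return not node_type.startswith('Random')

    assert can_const_fold('Add')
    assert not can_const_fold('RandomUniformInt')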
