
Commit 27f5e67

Merge pull request #471 from onnx/gs/fix-random-uniform
fix dynamic shape in tf.random_uniform for some cases
2 parents: d03e469 + 8d29126

File tree: 2 files changed (+67, -14 lines)

tests/test_backend.py
tf2onnx/rewriter/random_uniform.py

tests/test_backend.py

Lines changed: 27 additions & 0 deletions
@@ -1181,6 +1181,33 @@ def test_randomuniform_int(self):
         # since results are random, compare the shapes only
         self._run_test_case([_OUTPUT], {}, check_value=False, check_shape=True)
 
+    @skip_caffe2_backend()
+    def test_randomuniform_dyn_shape(self):
+        # test for dynamic shape coming from a shape op
+        x_val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
+        x = tf.placeholder(x_val.dtype, name=_TFINPUT)
+        x_ = tf.stack([x, x])
+        x_ = tf.identity(x_)
+        x_ = tf.shape(x_, name="shape")
+        x_ = tf.random_uniform(x_, name="rand", dtype=tf.float32)
+        x_ = tf.identity(x_)
+        _ = tf.identity(x_, name=_TFOUTPUT)
+        # since results are random, compare the shapes only
+        self._run_test_case([_OUTPUT], {_INPUT: x_val}, check_value=False, check_shape=True)
+
+    @skip_caffe2_backend()
+    def test_randomuniform_calc_shape(self):
+        # test for dynamic shape coming from some subgraph
+        x_val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
+        x = tf.placeholder(x_val.dtype, [None, 3], name=_TFINPUT)
+        x_ = tf.identity(x)
+        x_ = tf.shape(x_, name="shape")[1:]
+        x_ = tf.random_uniform(x_, name="rand", dtype=tf.float32)
+        x_ = tf.identity(x_)
+        _ = tf.identity(x_, name=_TFOUTPUT)
+        # since results are random, compare the shapes only
+        self._run_test_case([_OUTPUT], {_INPUT: x_val}, check_value=False, check_shape=True)
+
     @skip_caffe2_backend()
     def test_argminmax(self):
         x_val = np.array([0.5, 1.0, -0.5, -1.0], dtype=np.float32).reshape((2, 2))
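The two tests exercise different rewriter paths: in the first, random_uniform consumes the output of a Shape op directly, while in the second the slice tf.shape(...)[1:] inserts a StridedSlice between them, so the shape arrives from a computed subgraph. A minimal standalone sketch (TF 1.x API as used in the tests; not part of this commit) showing the distinction:

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 3], name="input")
    s = tf.shape(x, name="shape")[1:]           # slicing inserts a StridedSlice op
    r = tf.random_uniform(s, dtype=tf.float32)  # shape input now comes from StridedSlice

    # find the RandomUniform op and inspect what produces its shape input
    g = tf.get_default_graph()
    ru = [op for op in g.get_operations() if op.type == "RandomUniform"][0]
    print(ru.inputs[0].op.type)  # 'StridedSlice' here; a plain 'Shape' in the first test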

tf2onnx/rewriter/random_uniform.py

Lines changed: 40 additions & 14 deletions
@@ -4,8 +4,9 @@
 """
 tf2onnx.rewriter - rewrite tensorflow subgraph to onnx random_uniform op
 """
+import numpy as np
 from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
-from tf2onnx import utils
+from tf2onnx import utils, handler
 
 
 # pylint: disable=missing-docstring
@@ -29,10 +30,10 @@ def rewrite_random_uniform(g, ops):
         # max is on input 0
         tmax = input2.inputs[0].get_tensor_value()
         tmin = input2.inputs[1].get_tensor_value()
-
-        new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output)
+        to_delete = list(set(match.get_nodes()))
+        new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete)
         g.replace_all_inputs(ops, output.output[0], new_node.output[0])
-        for n in set(match.get_nodes()):
+        for n in to_delete:
             g.remove_node(n.name)
 
     return ops
@@ -59,25 +60,50 @@ def rewrite_random_uniform_fold_const(g, ops):
         tmax_minus_tmin = mul.inputs[1].get_tensor_value()
         tmin = output.inputs[1].get_tensor_value()
         tmax = tmin + tmax_minus_tmin
-        new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output)
+        to_delete = list(set(match.get_nodes()))
+        new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete)
         g.replace_all_inputs(ops, output.output[0], new_node.output[0])
-        for n in set(match.get_nodes()):
+        for n in to_delete:
             g.remove_node(n.name)
 
     return ops
 
 
-def create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output):
+def create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete):
     dtype = g.get_dtype(output.output[0])
     op_name = utils.make_name("RandomUniform")
-    if ru_op.inputs[0].type == "Shape":
-        shape_node = ru_op.inputs[0]
-        new_node = g.make_node("RandomUniformLike", inputs=[shape_node.input[0]], name=op_name,
-                               attr={"low": tmin, "high": tmax, "dtype": dtype},
-                               shapes=shape_node.output_shapes, dtypes=[dtype])
-    else:
-        shape = g.get_shape(output.output[0])
+    shape_node = ru_op.inputs[0]
+    shape = g.get_shape(output.output[0])
+    if shape_node.is_const():
+        # if the tensorflow input (aka the shape) is const we can use the RandomUniform op
         new_node = g.make_node("RandomUniform", [], name=op_name,
                                attr={"low": tmin, "high": tmax, "dtype": dtype, "shape": shape},
                                shapes=[shape], dtypes=[dtype])
+    else:
+        if shape_node.type == "Shape":
+            # if the shape is dynamic, tensorflow passes it as a tensor VALUE while
+            # onnx RandomUniformLike takes the shape from the tensor itself.
+            # In many cases there is a Shape op in tensorflow before RandomUniform, and
+            # to make that work for onnx we just need to remove the Shape op.
+            new_node = g.make_node("RandomUniformLike", inputs=[shape_node.input[0]], name=op_name,
+                                   attr={"low": tmin, "high": tmax, "dtype": dtype},
+                                   shapes=shape, dtypes=[dtype])
+        else:
+            # if the shape is calculated we need to create a tensor so RandomUniformLike
+            # can take the shape from there. Pre-opset9 this is somewhat hacky because there
+            # is no real fill op in onnx. In general this is not going to help performance,
+            # but the tensors created are expected to be small.
+
+            # tell the caller not to delete the shape node
+            to_delete.remove(shape_node)
+            # create a fill op with the shape of the value of the input tensor
+            zero = g.make_const(utils.make_name("zero"), np.zeros((), dtype=np.float32))
+            fill_node = g.make_node("Fill", inputs=[shape_node.output[0], zero.name],
+                                    shapes=shape, dtypes=[dtype])
+            func, _ = handler.tf_op.find_effective_op("Fill")
+            func(g, fill_node)
+            # and use RandomUniformLike to create the random tensor
+            new_node = g.make_node("RandomUniformLike", inputs=[fill_node.output[0]], name=op_name,
+                                   attr={"low": tmin, "high": tmax, "dtype": dtype},
+                                   shapes=shape, dtypes=[dtype])
     return new_node
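To illustrate what the two dynamic paths produce on the ONNX side, here is a minimal standalone sketch (not part of this commit; it uses onnx.helper directly, and ConstantOfShape from opset 9 stands in for the tensorflow-style Fill node that the rewriter emits and then runs through the Fill handler):

    import onnx
    from onnx import helper, TensorProto

    # Path 1 (input was a Shape op): RandomUniformLike copies the *shape* of its
    # input tensor and never reads its values, so the Shape op can be dropped and
    # RandomUniformLike wired directly to the Shape op's own input.
    x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3])
    y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3])
    rand = helper.make_node("RandomUniformLike", ["x"], ["y"], low=0.0, high=1.0)
    onnx.checker.check_model(
        helper.make_model(helper.make_graph([rand], "shape_path", [x], [y])))

    # Path 2 (shape computed by a subgraph): the shape exists only as a tensor
    # VALUE "s", so first materialize a dummy tensor of that shape, then let
    # RandomUniformLike borrow the dummy's shape.
    s = helper.make_tensor_value_info("s", TensorProto.INT64, [2])
    y2 = helper.make_tensor_value_info("y2", TensorProto.FLOAT, None)
    dummy = helper.make_node("ConstantOfShape", ["s"], ["dummy"])
    rand2 = helper.make_node("RandomUniformLike", ["dummy"], ["y2"], low=0.0, high=1.0)
    model = helper.make_model(
        helper.make_graph([dummy, rand2], "calc_shape_path", [s], [y2]),
        opset_imports=[helper.make_opsetid("", 9)])
    onnx.checker.check_model(model)

Before opset 9 there is no ConstantOfShape, which is why the rewriter instead emits a tensorflow-style Fill node and converts it through the existing Fill handler; that is the "somewhat hacky" expansion the comment refers to.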
