Skip to content

Commit 5e0c604

Browse files
committed
Correct UT
1 parent a2681da commit 5e0c604

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

tests/test_backend.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,18 +1370,13 @@ def func():
13701370
return tf.identity(x_, name=_TFOUTPUT)
13711371
# since results are random, compare the shapes only
13721372
self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
1373-
1373+
'''
13741374
@skip_caffe2_backend()
13751375
def test_randomuniform_dyn_shape(self):
13761376
# test for dynamic shape coming from a shape op
1377-
x_val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
1377+
x_val = np.array([0,1,2,3,5], dtype=np.int64)
13781378
def func(x):
1379-
x_ = tf.stack([x, x])
1380-
x_ = tf.identity(x_)
1381-
x_ = tf.shape(x_, name="shape")
1382-
x_ = random_uniform(x_, name="rand", dtype=tf.float32)
1383-
x_ = tf.identity(x_)
1384-
return tf.identity(x_, name=_TFOUTPUT)
1379+
return random_uniform(x[3:], name=_TFOUTPUT, dtype=tf.float32)
13851380
# since results are random, compare the shapes only
13861381
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, check_value=False, check_shape=True)
13871382

tf2onnx/onnx_opset/generator.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,22 @@ def version_1(cls, ctx, node, **kwargs):
3939
seed = node.get_attr("seed")
4040
node.set_attr("seed", float(seed.f))
4141
if len(node.input) > 0:
42-
if node.inputs[0].is_const():
43-
shape = node.inputs[0].get_tensor_value()
44-
ctx.remove_input(node, node.input[0])
45-
node.set_attr("shape", shape)
46-
ctx.set_shape(node.output[0], shape)
47-
else:
48-
cast_node = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64})
49-
const_node = ctx.make_node("ConstantOfShape", cast_node.output)
50-
node.input = const_node.output
51-
node.type = node.type + 'Like'
42+
shape = node.inputs[0].get_tensor_value()
43+
ctx.remove_input(node, node.input[0])
44+
node.set_attr("shape", shape)
45+
ctx.set_shape(node.output[0], shape)
46+
47+
@classmethod
48+
def version_9(cls, ctx, node, **kwargs):
49+
if node.inputs[0].is_const():
50+
version_1(cls, ctx, node, **kwargs)
51+
else:
52+
seed = node.get_attr("seed")
53+
node.set_attr("seed", float(seed.f))
54+
cast_node = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64})
55+
const_node = ctx.make_node("ConstantOfShape", cast_node.output)
56+
node.input = const_node.output
57+
node.type = node.type + 'Like'
5258

5359

5460
@tf_op(["RandomNormalLike", "RandomUniformLike"])

0 commit comments

Comments
 (0)