Skip to content

Commit aeaaa10

Browse files
Merge pull request #869 from RandySheriffH/rashuai/DynamicRandom
Rashuai/dynamic random
2 parents 9bd265a + cf50499 commit aeaaa10

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

tests/test_backend.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,14 +1374,9 @@ def func():
13741374
@skip_caffe2_backend()
13751375
def test_randomuniform_dyn_shape(self):
13761376
# test for dynamic shape coming from a shape op
1377-
x_val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
1377+
x_val = np.array([0, 1, 2, 3, 5], dtype=np.int64)
13781378
def func(x):
1379-
x_ = tf.stack([x, x])
1380-
x_ = tf.identity(x_)
1381-
x_ = tf.shape(x_, name="shape")
1382-
x_ = random_uniform(x_, name="rand", dtype=tf.float32)
1383-
x_ = tf.identity(x_)
1384-
return tf.identity(x_, name=_TFOUTPUT)
1379+
return random_uniform(x[3:], name=_TFOUTPUT, dtype=tf.float32)
13851380
# since results are random, compare the shapes only
13861381
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, check_value=False, check_shape=True)
13871382

tf2onnx/onnx_opset/generator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ def version_1(cls, ctx, node, **kwargs):
4444
node.set_attr("shape", shape)
4545
ctx.set_shape(node.output[0], shape)
4646

47+
@classmethod
48+
def version_9(cls, ctx, node, **kwargs):
49+
if node.inputs[0].is_const():
50+
cls.version_1(ctx, node, **kwargs)
51+
else:
52+
seed = node.get_attr("seed")
53+
node.set_attr("seed", float(seed.f))
54+
cast_node = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64})
55+
const_node = ctx.make_node("ConstantOfShape", cast_node.output)
56+
node.input = const_node.output
57+
node.type = node.type + 'Like'
58+
4759

4860
@tf_op(["RandomNormalLike", "RandomUniformLike"])
4961
class PassThroughOp:

0 commit comments

Comments
 (0)