Skip to content

Commit 98e6143

Browse files
Implement RandomStandardNormal conversion (#1484)
* Fix bug in rand_norm_rewriter for unknown shapes Signed-off-by: Tom Wildenhain <[email protected]> * Implement RandomStandardNormal Signed-off-by: Tom Wildenhain <[email protected]>
1 parent a6b141b commit 98e6143

File tree

3 files changed

+56
-6
lines changed

3 files changed

+56
-6
lines changed

tests/test_backend.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1962,6 +1962,44 @@ def func():
19621962
# since results are random, compare the shapes only
19631963
self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
19641964

1965+
def test_random_std_normal(self):
1966+
def func():
1967+
shape = tf.constant([20, 10, 50], name="shape")
1968+
x_ = tf.random.normal(shape)
1969+
return tf.identity(x_, name=_TFOUTPUT)
1970+
# since results are random, compare the shapes only
1971+
g = self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
1972+
results = self.run_backend(g, g.outputs, {})[0]
1973+
self.assertTrue(-0.1 < np.mean(results) < 0.1)
1974+
self.assertTrue(0.9 < np.std(results) < 1.1)
1975+
1976+
def test_randomnormal(self):
1977+
def func():
1978+
shape = tf.constant([20, 10, 50], name="shape")
1979+
x_ = tf.random.normal(shape, mean=10, stddev=2)
1980+
return tf.identity(x_, name=_TFOUTPUT)
1981+
# since results are random, compare the shapes only
1982+
g = self._run_test_case(func, [_OUTPUT], {}, check_value=False, check_shape=True)
1983+
results = self.run_backend(g, g.outputs, {})[0]
1984+
self.assertTrue(9.8 < np.mean(results) < 10.2)
1985+
self.assertTrue(1.9 < np.std(results) < 2.1)
1986+
1987+
@check_opset_min_version(9, "RandomNormalLike")
1988+
def test_randomnormal_unknown_shape(self):
1989+
shape_val = np.array([20, 10, 50], np.int32)
1990+
def func(shape):
1991+
x_ = tf.random.normal(shape)
1992+
return tf.identity(x_, name=_TFOUTPUT)
1993+
# since results are random, compare the shapes only
1994+
feed_dict = {_INPUT: shape_val}
1995+
g = self._run_test_case(func, [_OUTPUT], feed_dict, check_value=False, check_shape=True)
1996+
if "input" in g.input_names:
1997+
# TFLite inputs don't have port numbers
1998+
feed_dict = {k.split(":")[0]: v for k, v in feed_dict.items()}
1999+
results = self.run_backend(g, g.outputs, feed_dict)[0]
2000+
self.assertTrue(-0.1 < np.mean(results) < 0.1)
2001+
self.assertTrue(0.9 < np.std(results) < 1.1)
2002+
19652003
def test_randomuniform_int(self):
19662004
def func():
19672005
shape = tf.constant([100, 3], name="shape")

tf2onnx/onnx_opset/generator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def version_1(cls, ctx, node, **kwargs):
2929
pass
3030

3131

32-
@tf_op(["RandomNormal", "RandomUniform", "RandomUniformInt"])
32+
@tf_op(["RandomNormal", "RandomUniform", "RandomUniformInt", "RandomStandardNormal"])
3333
class RandomOp:
3434
@classmethod
3535
def randuniform_int(cls, ctx, rand_node, rand_out, min_inp, max_inp):
@@ -66,7 +66,7 @@ def version_1(cls, ctx, node, **kwargs):
6666
# const we make it an attribute.
6767
seed = node.get_attr("seed")
6868
node.set_attr("seed", float(seed.f))
69-
utils.make_sure(node.inputs[0].is_const(), "%s node with non-const shape requires opset >= 9")
69+
utils.make_sure(node.inputs[0].is_const(), "%s node with non-const shape requires opset >= 9", node.type)
7070
shape = node.inputs[0].get_tensor_value()
7171
ctx.remove_input(node, node.input[0], 0)
7272
if len(shape) == 0:
@@ -84,6 +84,8 @@ def version_1(cls, ctx, node, **kwargs):
8484
cls.randuniform_int(ctx, node, rand_out, node.input[0], node.input[1])
8585
node.type = "RandomUniform"
8686
ctx.replace_inputs(node, [])
87+
elif node.type == "RandomStandardNormal":
88+
node.type = "RandomNormal"
8789

8890
@classmethod
8991
def version_9(cls, ctx, node, **kwargs):
@@ -99,6 +101,8 @@ def version_9(cls, ctx, node, **kwargs):
99101
if node.type == "RandomUniformInt":
100102
cls.randuniform_int(ctx, node, node.output[0], inputs[1], inputs[2])
101103
node.type = "RandomUniformLike"
104+
elif node.type == "RandomStandardNormal":
105+
node.type = "RandomNormalLike"
102106
else:
103107
node.type = node.type + 'Like'
104108

tf2onnx/rewriter/random_normal_rewriter.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,29 @@ def rewrite_random_normal(g, ops):
3939
else:
4040
# pattern 2
4141
mean = 0.0
42+
input2 = match.get_op('input2')
43+
if input2.type == 'Mul':
44+
scale = input2.inputs[1].get_tensor_value()
45+
else:
46+
scale = 1.0
4247
dtype = g.get_dtype(output.output[0])
4348
op_name = utils.make_name("RandomNormal")
4449
out_name = utils.port_name(op_name)
4550

4651
rn_op = match.get_op('input1')
47-
seed = rn_op.get_attr('seed2').i
52+
seed = float(rn_op.get_attr('seed2').i)
4853

54+
attr = {"mean": mean, "scale": scale, "dtype": dtype, "seed": seed}
4955
if rn_op.inputs[0].type == "Shape":
5056
shape_node = rn_op.inputs[0]
5157
new_node = g.make_node("RandomNormalLike", [shape_node.input[0]], outputs=[out_name], name=op_name,
52-
attr={"mean": mean, "scale": 1.0, "dtype": dtype, "seed": float(seed)})
58+
attr=attr)
5359
else:
5460
shape = g.get_shape(output.output[0])
55-
new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name,
56-
attr={"shape": shape, "mean": mean, "scale": 1.0, "dtype": dtype, "seed": seed})
61+
if shape is None or -1 in shape:
62+
continue
63+
attr['shape'] = shape
64+
new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name, attr=attr)
5765

5866
g.replace_all_inputs(output.output[0], new_node.output[0], ops=ops)
5967
g.safe_remove_nodes(match.get_nodes())

0 commit comments

Comments
 (0)