Skip to content

Commit c1eb6a8

Browse files
authored
Merge pull request #949 from jignparm/jignparm/fix_randomnormallike
Add new pattern for RandomStandardNormal op in TF2
2 parents f5400e9 + 66afe9c commit c1eb6a8

File tree

1 file changed

+38
-23
lines changed

1 file changed

+38
-23
lines changed

tf2onnx/rewriter/random_normal_rewriter.py

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,48 @@
1313

1414

1515
def rewrite_random_normal(g, ops):
16-
pattern = \
16+
pattern1 = \
1717
OpTypePattern('Add', name='output', inputs=[
1818
OpTypePattern('Mul', name='input2', inputs=[
1919
OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "*"
2020
]), "*"
2121
])
2222

23-
matcher = GraphMatcher(pattern)
24-
match_results = list(matcher.match_ops(ops))
25-
for match in match_results:
26-
output = match.get_op('output')
27-
mean = output.inputs[1].get_tensor_value()
28-
dtype = g.get_dtype(output.output[0])
29-
op_name = utils.make_name("RandomNormal")
30-
out_name = utils.port_name(op_name)
31-
32-
rn_op = match.get_op('input1')
33-
seed = rn_op.get_attr('seed2').i
34-
if rn_op.inputs[0].type == "Shape":
35-
shape_node = rn_op.inputs[0]
36-
new_node = g.make_node("RandomNormalLike", [shape_node.input[0]], outputs=[out_name], name=op_name,
37-
attr={"mean": mean, "scale": 1.0, "dtype": dtype, "seed": seed})
38-
else:
39-
shape = g.get_shape(output.output[0])
40-
new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name,
41-
attr={"shape": shape, "mean": mean, "scale": 1.0, "dtype": dtype, "seed": seed})
42-
43-
g.replace_all_inputs(ops, output.output[0], new_node.output[0])
44-
g.safe_remove_nodes(match.get_nodes())
23+
pattern2 = \
24+
OpTypePattern('Identity', name='output', inputs=[
25+
OpTypePattern('Identity', name='input2', inputs=[
26+
OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"])
27+
])
28+
])
29+
30+
pattern_list = [pattern1, pattern2]
31+
for pattern in pattern_list:
32+
matcher = GraphMatcher(pattern)
33+
match_results = list(matcher.match_ops(ops))
34+
for match in match_results:
35+
output = match.get_op('output')
36+
if output.type == 'Add':
37+
# pattern 1
38+
mean = output.inputs[1].get_tensor_value()
39+
else:
40+
# pattern 2
41+
mean = 0.0
42+
dtype = g.get_dtype(output.output[0])
43+
op_name = utils.make_name("RandomNormal")
44+
out_name = utils.port_name(op_name)
45+
46+
rn_op = match.get_op('input1')
47+
seed = rn_op.get_attr('seed2').i
48+
49+
if rn_op.inputs[0].type == "Shape":
50+
shape_node = rn_op.inputs[0]
51+
new_node = g.make_node("RandomNormalLike", [shape_node.input[0]], outputs=[out_name], name=op_name,
52+
attr={"mean": mean, "scale": 1.0, "dtype": dtype, "seed": float(seed)})
53+
else:
54+
shape = g.get_shape(output.output[0])
55+
new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name,
56+
attr={"shape": shape, "mean": mean, "scale": 1.0, "dtype": dtype, "seed": seed})
57+
58+
g.replace_all_inputs(ops, output.output[0], new_node.output[0])
59+
g.safe_remove_nodes(match.get_nodes())
4560
return ops

0 commit comments

Comments
 (0)