Skip to content

Commit 597c70f

Browse files
committed
Fix seed attr for RandomNormal and RandomNormaLike ops
1 parent 84cfa03 commit 597c70f

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tf2onnx/rewriter/random_normal_rewriter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@ def rewrite_random_normal(g, ops):
3030
out_name = utils.port_name(op_name)
3131

3232
rn_op = match.get_op('input1')
33+
seed = rn_op.get_attr('seed2').i
3334
if rn_op.inputs[0].type == "Shape":
3435
shape_node = rn_op.inputs[0]
3536
new_node = g.make_node("RandomNormalLike", [shape_node.input[0]], outputs=[out_name], name=op_name,
36-
attr={"mean": mean, "scale": 1.0, "dtype": dtype})
37+
attr={"mean": mean, "scale": 1.0, "dtype": dtype, "seed": seed})
3738
else:
3839
shape = g.get_shape(output.output[0])
3940
new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name,
40-
attr={"shape": shape, "mean": mean, "scale": 1.0, "dtype": dtype})
41+
attr={"shape": shape, "mean": mean, "scale": 1.0, "dtype": dtype, "seed": seed})
4142

4243
g.replace_all_inputs(ops, output.output[0], new_node.output[0])
4344
g.safe_remove_nodes(match.get_nodes())

0 commit comments

Comments
 (0)