Skip to content

Commit 9713ffd

Browse files
authored
Merge pull request #473 from lucienwang1009/bug
fix create_onnx_random_uniform_op shape
2 parents 7538f59 + 3a349d9 commit 9713ffd

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tf2onnx/rewriter/random_uniform.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete):
8787
# to make that work for onnx we just need to remove the shape op.
8888
new_node = g.make_node("RandomUniformLike", inputs=[shape_node.input[0]], name=op_name,
8989
attr={"low": tmin, "high": tmax, "dtype": dtype},
90-
shapes=shape, dtypes=[dtype])
90+
shapes=[shape], dtypes=[dtype])
9191
else:
9292
# if the shape is calculated we need to create a tensor so RandomUniformLike
9393
# can take the shape from there. Pre opset9 this is somewhat hacky because there is
@@ -99,11 +99,11 @@ def create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete):
9999
# create a fill op with the shape of the value of the input tensor
100100
zero = g.make_const(utils.make_name("zero"), np.zeros((), dtype=np.float32))
101101
fill_node = g.make_node("Fill", inputs=[shape_node.output[0], zero.name],
102-
shapes=shape, dtypes=[dtype])
102+
shapes=[shape], dtypes=[dtype])
103103
func, _ = handler.tf_op.find_effective_op("Fill")
104104
func(g, fill_node)
105105
# and use RandomUniformLike to create the random tensor
106106
new_node = g.make_node("RandomUniformLike", inputs=[fill_node.output[0]], name=op_name,
107107
attr={"low": tmin, "high": tmax, "dtype": dtype},
108-
shapes=shape, dtypes=[dtype])
108+
shapes=[shape], dtypes=[dtype])
109109
return new_node

0 commit comments

Comments
 (0)