Skip to content

Commit a2681da

Browse files
committed
support dynamic input shape
1 parent 9bd265a commit a2681da

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

tf2onnx/onnx_opset/generator.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,16 @@ def version_1(cls, ctx, node, **kwargs):
3939
seed = node.get_attr("seed")
4040
node.set_attr("seed", float(seed.f))
4141
if len(node.input) > 0:
42-
shape = node.inputs[0].get_tensor_value()
43-
ctx.remove_input(node, node.input[0])
44-
node.set_attr("shape", shape)
45-
ctx.set_shape(node.output[0], shape)
42+
if node.inputs[0].is_const():
43+
shape = node.inputs[0].get_tensor_value()
44+
ctx.remove_input(node, node.input[0])
45+
node.set_attr("shape", shape)
46+
ctx.set_shape(node.output[0], shape)
47+
else:
48+
cast_node = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64})
49+
const_node = ctx.make_node("ConstantOfShape", cast_node.output)
50+
node.input = const_node.output
51+
node.type = node.type + 'Like'
4652

4753

4854
@tf_op(["RandomNormalLike", "RandomUniformLike"])

0 commit comments

Comments
 (0)