Skip to content

Commit 4bba41c

Browse files
authored
Merge pull request #966 from jignparm/jignparm/fix_conv2dbackpropinput_shape
Set output shape of ConvTranspose (Conv2DBackpropInput) correctly
2 parents 5558957 + 84aae67 commit 4bba41c

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def version_1(cls, ctx, node, **kwargs):
248248
# Note: inputs are reversed from what one would expect.
249249
conv_kernel_shape(ctx, node, 1)
250250
input_shape = ctx.get_shape(node.input[2])
251+
output_shape_orig = node.output_shapes
251252

252253
# ouput_shape is explicitly specified here, in this case pads values are auto generated/calculated.
253254
if node.inputs[0].is_const():
@@ -285,7 +286,8 @@ def version_1(cls, ctx, node, **kwargs):
285286
const_one_two = ctx.make_const(utils.make_name(node.name + "_const_one_two"),
286287
np.array([1, 2], dtype=np.int64))
287288
slice_node = ctx.make_node("Slice",
288-
[node.output[0], starts.output[0], ends.output[0], const_one_two.output[0]])
289+
[node.output[0], starts.output[0], ends.output[0], const_one_two.output[0]],
290+
shapes=output_shape_orig)
289291
downstream_nodes = ctx.find_output_consumers(node.output[0])
290292
downstream_nodes.remove(output_shape)
291293
downstream_nodes.remove(slice_node)

0 commit comments

Comments
 (0)