Skip to content

Commit c33b4e6

Browse files
lei-Qiaowayuanho
authored andcommitted
Add checker when setting convtranspose’s attribute (#603)
* Add checker when setting convtranspose’s attribute
1 parent da2d7a6 commit c33b4e6

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,13 +220,21 @@ def version_1(cls, ctx, node, **kwargs):
220220
node.type = "ConvTranspose"
221221
# Note: inputs are reversed from what one would expect.
222222
kernel_shape = conv_kernel_shape(ctx, node, 1)
223+
input_shape = ctx.get_shape(node.input[2])
223224

224225
# ouput_shape is explicitly specified here, in this case pads values are auto generated/calculated.
225226
output_shape = ctx.get_shape(node.output[0])
226227
if node.is_nhwc():
227228
new_output_shape = [output_shape[1], output_shape[2]]
229+
input_hw = [input_shape[1], input_shape[2]]
228230
else:
229231
new_output_shape = [output_shape[2], output_shape[3]]
232+
input_hw = [input_shape[2], input_shape[3]]
233+
234+
utils.make_sure(new_output_shape.count(-1) <= 0, "output h and w need to be known")
235+
utils.make_sure(new_output_shape[0] >= input_hw[0] and new_output_shape[1] >= input_hw[1],
236+
"output h and w cannot be smaller than input h and w.")
237+
230238
node.set_attr("output_shape", new_output_shape)
231239

232240
strides = conv_dims_attr(node, "strides")

0 commit comments

Comments
 (0)