@@ -463,7 +463,7 @@ def conv_dims_attr(node, name, new_name=None):
463
463
464
464
465
465
def conv_kernel_shape (ctx , node , input_idx , spatial = 2 ):
466
- kernel_shape = ctx .get_shape (node .input [1 ])
466
+ kernel_shape = ctx .get_shape (node .input [input_idx ])
467
467
if len (kernel_shape ) != 2 * spatial :
468
468
raise ValueError ("kernel rank must be 2* spatial" )
469
469
kernel_shape = kernel_shape [0 :spatial ]
@@ -492,6 +492,8 @@ def convtranspose_op(ctx, node, name, args):
492
492
493
493
# Note: inputs are reversed from what one would expect.
494
494
kernel_shape = conv_kernel_shape (ctx , node , 1 )
495
+
496
+ # output_shape is explicitly specified here; in this case the pads values are auto-generated/calculated.
495
497
output_shape = node .inputs [0 ].get_tensor_value ()
496
498
if node .is_nhwc ():
497
499
new_output_shape = [output_shape [1 ], output_shape [2 ]]
@@ -501,15 +503,17 @@ def convtranspose_op(ctx, node, name, args):
501
503
502
504
strides = conv_dims_attr (node , "strides" )
503
505
conv_dims_attr (node , "dilations" )
504
- add_padding (ctx , node , kernel_shape , strides )
505
506
506
- # remove output_shapes input, swap data and kernel
507
+ # remove output_shapes input
507
508
ctx .remove_input (node , node .input [0 ])
509
+ # swap data and kernel
508
510
t = node .input [0 ]
509
511
node .input [0 ] = node .input [1 ]
510
512
node .input [1 ] = t
511
513
512
514
nodes = conv_convert_inputs (ctx , node , with_kernel = True )
515
+
516
+ # Note: output_padding and group are left at their defaults.
513
517
return nodes
514
518
515
519
0 commit comments