@@ -15,6 +15,23 @@
 from .Dense import _activation_map


+def _calc_explicit_padding(input_size, output_shape, output_padding, kernel_shape, stride, dilation, perm):
+    # Permute both shapes into NCHW order and keep only the spatial dimensions.
+    to_nchw = lambda x, perm: [x[perm[n_]] for n_ in range(len(x))]
+    input_size = to_nchw(input_size, perm)[2:]
+    output_shape = to_nchw(output_shape, perm)[2:]
+
+    spatial = len(kernel_shape)
+    total_padding = []
+    pads = [None] * 2 * spatial
+    for i in range(spatial):
+        # Total padding needed on axis i for the given input to yield the given output.
+        total_padding[i:] = [stride[i] * (output_shape[i] - 1) +
+                             output_padding[i] + kernel_shape[i] * dilation[i] - input_size[i]]
+        # Split begin/end, putting any odd leftover pixel at the end (SAME_UPPER convention).
+        pads[i] = total_padding[i] // 2
+        pads[i + spatial] = total_padding[i] - (total_padding[i] // 2)
+
+    return pads
+
+
 def convert_keras_conv_core(scope, operator, container, is_transpose, n_dims, input_perm_axes,
                             output_perm_axes, weight_perm_axes):
     op = operator.raw_operator
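
As a hypothetical sanity check of the helper above (all shapes, strides, and the perm here are assumed for illustration, not taken from this commit): per spatial axis it computes total = stride * (out - 1) + output_padding + kernel * dilation - in, then splits that between the begin and end pads.

pads = _calc_explicit_padding((None, 8, 8, 3),    # Keras input_shape, channels_last, batch unknown
                              (None, 4, 4, 16),   # Keras output_shape of a stride-2 'same' conv
                              [0, 0],             # no output_padding for a forward conv
                              (3, 3),             # kernel_shape
                              (2, 2),             # stride
                              (1, 1),             # dilation
                              [0, 3, 1, 2])       # NHWC -> NCHW permutation
print(pads)   # [0, 0, 1, 1]: ONNX order [h_begin, w_begin, h_end, w_end]

The odd leftover pixel lands at the end of each axis, matching SAME_UPPER, which is why the unknown-shape fallback in the next hunk picks SAME_UPPER for a regular convolution.
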
@@ -69,22 +86,27 @@ def convert_keras_conv_core(scope, operator, container, is_transpose, n_dims, in |
     attrs['dilations'] = list(op.dilation_rate)
     attrs['strides'] = list(op.strides)
     attrs['kernel_shape'] = op.kernel_size
-    # Fix this...
     attrs['group'] = group

     if op.padding == 'valid':
         attrs['auto_pad'] = 'VALID'
     elif op.padding == 'same':
-        if is_transpose:  # bypass onnx engine issue on convtranspose support.
-            attrs['auto_pad'] = 'SAME_LOWER'
-            shape = [-1 if i is None else i for i in op.output_shape]
-            if channels_first:
-                attrs['output_shape'] = shape
+        if op.input_shape.count(None) > 1:
+            # More than the batch dimension is unknown, so explicit pads cannot
+            # be computed statically; fall back to auto_pad.
+            if is_transpose:
+                attrs['auto_pad'] = 'SAME_LOWER'  # the controversial definition in the ONNX spec.
             else:
-                attrs['output_shape'] = shape[0:1] + shape[-1:] + shape[1:-1]
-
+                attrs['auto_pad'] = 'SAME_UPPER'
         else:
-            attrs['auto_pad'] = 'SAME_LOWER'
+            output_padding = [0] * len(op.kernel_size)
+            if hasattr(op, 'output_padding') and op.output_padding is not None:
+                output_padding = op.output_padding
+            # For ConvTranspose, input and output swap roles so the pads come
+            # from the equivalent forward convolution.
+            attrs['pads'] = _calc_explicit_padding(op.output_shape if is_transpose else op.input_shape,
+                                                   op.input_shape if is_transpose else op.output_shape,
+                                                   output_padding,
+                                                   op.kernel_size,
+                                                   op.strides,
+                                                   op.dilation_rate,
+                                                   list(range(len(op.input_shape))) if channels_first else input_perm_axes)
     else:
         raise RuntimeError("Unsupported padding type '{}'".format(op.padding))

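
A minimal sketch of the new branch in isolation, using a hypothetical stand-in for the Keras layer (every attribute value below is assumed). With only the batch dimension unknown, the converter takes the explicit-pads path; only when spatial dimensions are also None does it fall back to auto_pad:

class _FakeConv:                       # hypothetical stand-in, not a real Keras layer
    padding = 'same'
    input_shape = (None, 8, 8, 3)      # count(None) == 1: only the batch dim is unknown
    output_shape = (None, 4, 4, 16)
    kernel_size = (3, 3)
    strides = (2, 2)
    dilation_rate = (1, 1)
    output_padding = None

op = _FakeConv()
attrs = {}
if op.input_shape.count(None) > 1:
    attrs['auto_pad'] = 'SAME_UPPER'   # spatial sizes unknown: let the runtime pad
else:
    attrs['pads'] = _calc_explicit_padding(op.input_shape, op.output_shape, [0, 0],
                                           op.kernel_size, op.strides, op.dilation_rate,
                                           [0, 3, 1, 2])
print(attrs)   # {'pads': [0, 0, 1, 1]}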
|
|