Commit b7a9835: fixes for conv padding
1 parent 05ee1a7

2 files changed: +123 -56 lines changed

tests/test_backend.py

Lines changed: 81 additions & 20 deletions
@@ -7,6 +7,7 @@
 import tempfile
 import unittest
 from collections import namedtuple
+from itertools import product

 import numpy as np
 import tensorflow as tf
@@ -16,6 +17,10 @@

 TMPPATH = tempfile.mkdtemp()

+# we can override BACKEND and OPSET from the command line, but that is too late
+# to change the behavior of annotations. If needed, pick the backend here.
+OPSET = 7
+
 BACKEND = "caffe2"
 # BACKEND = "onnxmsrt"
 # BACKEND = "onnxmsrtnext"
@@ -37,7 +42,6 @@
 _OUTPUT = "output:0"
 _OUTPUT1 = "output1:0"

-OPSET = 7

 # pylint: disable=C0111

@@ -47,6 +51,50 @@ def make_xval(shape):
     return x_val


+def get_conv_getdata(kind=1):
+    if kind == 0:
+        # generate all combinations (costly)
+        dims = [
+            ("padding", ["SAME", "VALID"]),
+            ("input_sizes", [[32, 35, 35, 288], [32, 17, 17, 1248], [1, 28, 28, 3], [32, 8, 8, 2048]]),
+            ("filter_sizes", [[1, 3, 3, 1], [1, 2, 2, 1], [1, 5, 5, 1], [1, 1, 1, 1], [1, 5, 2, 1], [1, 2, 5, 1]]),
+            ("strides", [[1, 2, 2, 1], [1, 1, 1, 1]]),
+        ]
+        values = [key_values[1] for key_values in dims]
+        for idx, v in enumerate(product(*values)):
+            if True or idx == 30:
+                yield (idx,) + v
+    elif kind == 1:
+        # some combinations that give decent padding coverage
+        data = [
+            ('SAME', [32, 35, 35, 288], [1, 3, 3, 1], [1, 2, 2, 1]),
+            ('SAME', [32, 35, 35, 288], [1, 2, 2, 1], [1, 2, 2, 1]),
+            ('SAME', [32, 35, 35, 288], [1, 2, 2, 1], [1, 1, 1, 1]),
+            ('SAME', [32, 35, 35, 288], [1, 5, 5, 1], [1, 1, 1, 1]),
+            ('SAME', [32, 35, 35, 288], [1, 1, 1, 1], [1, 2, 2, 1]),
+            ('SAME', [32, 35, 35, 288], [1, 1, 1, 1], [1, 1, 1, 1]),
+            ('SAME', [32, 35, 35, 288], [1, 5, 2, 1], [1, 2, 2, 1]),
+            ('SAME', [32, 35, 35, 288], [1, 2, 5, 1], [1, 2, 2, 1]),
+            ('SAME', [32, 35, 35, 288], [1, 2, 5, 1], [1, 1, 1, 1]),
+            ('SAME', [1, 28, 28, 3], [1, 3, 3, 1], [1, 2, 2, 1]),
+            ('SAME', [1, 28, 28, 3], [1, 3, 3, 1], [1, 1, 1, 1]),
+            ('SAME', [1, 28, 28, 3], [1, 2, 2, 1], [1, 2, 2, 1]),
+            ('SAME', [1, 28, 28, 3], [1, 2, 2, 1], [1, 1, 1, 1]),
+            ('SAME', [1, 28, 28, 3], [1, 5, 5, 1], [1, 2, 2, 1]),
+            ('SAME', [1, 28, 28, 3], [1, 5, 5, 1], [1, 1, 1, 1]),
+            ('SAME', [1, 28, 28, 3], [1, 5, 2, 1], [1, 2, 2, 1]),
+            ('SAME', [1, 28, 28, 3], [1, 2, 5, 1], [1, 1, 1, 1]),
+            ('SAME', [32, 8, 8, 2048], [1, 3, 3, 1], [1, 2, 2, 1]),
+            ('SAME', [32, 8, 8, 2048], [1, 3, 3, 1], [1, 1, 1, 1]),
+            ('VALID', [32, 35, 35, 288], [1, 3, 3, 1], [1, 1, 1, 1]),
+            ('VALID', [32, 35, 35, 288], [1, 2, 2, 1], [1, 2, 2, 1]),
+        ]
+        for idx, v in enumerate(data):
+            yield (idx,) + v
+    else:
+        raise ValueError("kind not known")
+
+
 class Tf2OnnxBackendTests(unittest.TestCase):
     def setUp(self):
         self.maxDiff = None
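
Note on the new generator: each yielded tuple is (index, padding, input_sizes, filter_sizes, strides), so the tests below can unpack one parameter set per iteration. A minimal sketch of its output (hand-written here, not part of the commit):

    for p in get_conv_getdata(kind=1):
        idx, padding, input_sizes, filter_sizes, strides = p
        # (0, 'SAME', [32, 35, 35, 288], [1, 3, 3, 1], [1, 2, 2, 1])
        # (1, 'SAME', [32, 35, 35, 288], [1, 2, 2, 1], [1, 2, 2, 1])
        # ...
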
@@ -198,20 +246,26 @@ def test_multinomial1(self):
         self.assertEqual(expected.shape, actual.shape)

     def test_maxppol(self):
-        x_val = make_xval((1, 4, 4, 1))
-        x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
-        mp = tf.nn.max_pool(x, [1, 2, 2, 1], _STRIDE1x1, padding="VALID")
-        output = tf.identity(mp, name=_TFOUTPUT)
-        actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
-        self.assertAllClose(expected, actual)
+        for p in get_conv_getdata():
+            idx, padding, x_shape, ksize, strides = p
+            tf.reset_default_graph()
+            x_val = make_xval(x_shape)
+            x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
+            mp = tf.nn.max_pool(x, ksize, strides, padding=padding)
+            output = tf.identity(mp, name=_TFOUTPUT)
+            actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
+            self.assertAllClose(expected, actual, err_msg=str(p))

     def test_avgppol(self):
-        x_val = make_xval((1, 4, 4, 1))
-        x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
-        mp = tf.nn.avg_pool(x, [1, 2, 2, 1], _STRIDE1x1, padding="VALID")
-        output = tf.identity(mp, name=_TFOUTPUT)
-        actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
-        self.assertAllClose(expected, actual)
+        for p in get_conv_getdata(kind=0):
+            idx, padding, x_shape, ksize, strides = p
+            tf.reset_default_graph()
+            x_val = make_xval(x_shape)
+            x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
+            mp = tf.nn.avg_pool(x, ksize, strides, padding=padding)
+            output = tf.identity(mp, name=_TFOUTPUT)
+            actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
+            self.assertAllClose(expected, actual, err_msg=str(p))

     def _conv_test(self, x_val, w, strides=None, padding="VALID", dilations=None):
         if strides is None:
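
For reference, the output sizes these parametrized tests exercise follow TensorFlow's conventions: SAME produces ceil(input / stride) regardless of kernel size, while VALID uses only fully valid windows. A minimal sketch of that expectation (plain Python 3, not part of the commit):

    import math

    def expected_output_size(input_size, kernel, stride, padding):
        if padding == "SAME":
            return math.ceil(input_size / stride)                  # e.g. 35, k=3, s=2 -> 18
        return math.ceil((input_size - kernel + 1) / stride)       # VALID: 35, k=3, s=2 -> 17
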
@@ -753,13 +807,20 @@ def test_pow_scalar(self):

     @unittest.skipIf(BACKEND == "caffe2", "not supported correctly in caffe2")
     def test_pad(self):
-        x_val = np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], dtype=np.float32)
-        x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
-        paddings = tf.constant([[0, 0, ], [2, 0]])
-        op = tf.pad(x, paddings, "CONSTANT")
-        output = tf.identity(op, name=_TFOUTPUT)
-        actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
-        self.assertAllClose(expected, actual)
+        params = [
+            ("CONSTANT", [[1, 1], [2, 2]], [[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]]),
+            ("CONSTANT", [[0, 0], [3, 3], [3, 3], [0, 0]], np.random.randn(1, 3, 4, 5).astype(np.float32)),
+        ]
+        for p in params:
+            tf.reset_default_graph()
+            mode, pad, xv = p
+            x_val = np.array(xv, dtype=np.float32)
+            x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
+            paddings = tf.constant(pad)
+            op = tf.pad(x, paddings, mode)
+            output = tf.identity(op, name=_TFOUTPUT)
+            actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
+            self.assertAllClose(expected, actual, err_msg=str(p))

     @unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "not supported correctly in caffe2")
     def test_randomuniform(self):
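
Each row of the paddings constant is a (before, after) pair for one dimension, so the second case above pads only the two spatial axes of the NHWC tensor. A quick NumPy sketch of the expected shape (not part of the commit):

    import numpy as np

    x = np.zeros((1, 3, 4, 5), dtype=np.float32)
    y = np.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], mode="constant")
    assert y.shape == (1, 9, 10, 5)   # 3+3+3 = 9, 3+4+3 = 10
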

tf2onnx/tfonnx.py

Lines changed: 42 additions & 36 deletions
@@ -304,6 +304,11 @@ def reshape_op5(ctx, node, name, args):
 HWCN_TO_NCHW = [3, 2, 0, 1]
 NCHW_TO_HWCN = [2, 3, 1, 0]

+def spatial_map(shape, perm):
+    new_shape = shape[:]
+    for i in perm:
+        new_shape[i] = shape[perm[i]]
+    return new_shape

 def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
                         input_indices=None, output_indices=None):
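
spatial_map permutes a shape list the same way a Transpose with the given perm permutes the tensor. A minimal sketch, assuming NHWC_TO_NCHW = [0, 3, 1, 2] as defined earlier in this file:

    NHWC_TO_NCHW = [0, 3, 1, 2]
    spatial_map([1, 28, 28, 3], NHWC_TO_NCHW)   # -> [1, 3, 28, 28], i.e. NHWC to NCHW
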
@@ -320,11 +325,6 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
         new_kernel_shape: reshape the kernel
     """

-    def calc_shape(a, b):
-        if a and b:
-            return [a[b[i]] for i in b]
-        return None
-
     if input_indices is None:
         input_indices = [0]
     if output_indices is None:
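
The removed calc_shape was subtly wrong: it appended a[b[i]] in iteration order instead of assigning by index, so permuted shapes came out scrambled. A hand-checked sketch, again assuming NHWC_TO_NCHW = [0, 3, 1, 2]:

    def calc_shape(a, b):                        # the helper deleted above
        if a and b:
            return [a[b[i]] for i in b]
        return None

    calc_shape([1, 28, 28, 3], [0, 3, 1, 2])     # -> [1, 28, 3, 28], not [1, 3, 28, 28]
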
@@ -335,42 +335,44 @@ def calc_shape(a, b):
     if node.is_nhwc():
         # transpose input if needed, no need to record shapes on input
         for idx in input_indices:
+            parent = node.inputs[idx]
             if node.inputs[idx].is_const():
                 # if input is a constant, transpose that one
-                parent = node.inputs[idx]
                 if not parent.data_format:
                     val = parent.get_tensor_value()
                     parent.set_tensor_value(val.transpose(NHWC_TO_NCHW))
-                    parent.data_format = "NCHW"
             else:
                 # if input comes from an op, insert transpose op
                 input_name = node.input[idx]
                 transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
                 transpose.set_attr("perm", NHWC_TO_NCHW)
                 transpose.inserted_nchw = True
-                if idx == 0:
-                    ctx.set_shape(transpose.output[0], calc_shape(ctx.get_shape(input_name), NHWC_TO_NCHW))
+                shape = ctx.get_shape(input_name)
+                new_shape = spatial_map(shape, NHWC_TO_NCHW)
+                ctx.set_shape(transpose.output[0], new_shape)
                 nodes.append(transpose)
+            parent.data_format = "NCHW"

         # kernel needs to be transposed
         if with_kernel:
+            parent = node.inputs[1]
             if node.inputs[1].is_const():
                 # kernel is const - transpose the const
-                parent = node.inputs[1]
                 if not parent.data_format:
                     val = parent.get_tensor_value()
                     val = val.transpose(HWCN_TO_NCHW)
                     parent.set_tensor_value(val)
-                    parent.data_format = "NCHW"
             else:
                 # kernel comes from op, insert transpose op
                 input_name = node.input[1]
                 transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
                 transpose.set_attr("perm", HWCN_TO_NCHW)
                 transpose.inserted_nchw = True
                 ctx.copy_shape(input_name, transpose.output[0])
-                ctx.set_shape(transpose.output[0], calc_shape(ctx.get_shape(input_name), HWCN_TO_NCHW))
+                new_shape = spatial_map(ctx.get_shape(input_name), HWCN_TO_NCHW)
+                ctx.set_shape(transpose.output[0], new_shape)
                 nodes.append(transpose)
+            parent.data_format = "NCHW"

         # some onnx conv ops require reshaping the kernel (i.e. depthwise_conv2d)
         if new_kernel_shape:
@@ -379,46 +381,52 @@ def calc_shape(a, b):
             input_name = node.input[1]
             reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
             reshape.set_attr("shape", new_kernel_shape)
-            ctx.set_shape(reshape.output[0], new_kernel_shape)
         else:
             # new reshape takes new shape as input[1]
             shape_name = utils.make_name(node.name)
             shape_node = ctx.make_const(shape_name, "Const", np.array(new_kernel_shape, dtype=np.int64))
             input_name = node.input[1]
             reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
             reshape.input.append(shape_name)
-            ctx.set_shape(reshape.output[0], new_kernel_shape)
+        ctx.set_shape(reshape.output[0], new_kernel_shape)
         nodes.append(reshape)

     # insert conv node after inputs
     nodes.append(node)

     # transpose outputs if needed
     if node.is_nhwc():
-        # TODO: what if len(output) > 0 ?
         for idx in output_indices:
             output_name = node.output[idx]
             op_name = utils.make_name(node.name)
             transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)
             transpose.set_attr("perm", NCHW_TO_NHWC)
             transpose.inserted_nchw = True
-            ctx.set_shape(transpose.output[0], calc_shape(ctx.get_shape(node.output[idx]), NCHW_TO_NHWC))
+            ctx.set_shape(transpose.output[0], ctx.get_shape(node.output[idx]))
             nodes.append(transpose)
+        node.data_format = "NCHW"
     return nodes


-def add_padding(node, kernel_shape, strides):
+def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
     padding = node.get_attr("padding")
     if padding:
+        if dilations is None:
+            dilations = [1] * spatial * 2
         padding = padding.s.decode("utf-8")
         if padding == 'SAME':
-            s_h, s_w = strides[0], strides[1]
-            k_h, k_w = kernel_shape[0], kernel_shape[1]
-            p_x0 = (k_w - s_w) // 2
-            p_y0 = (k_h - s_h) // 2
-            p_x1 = k_w - s_w - p_x0
-            p_y1 = k_h - s_h - p_y0
-            node.set_attr("pads", [p_y0, p_x0, p_y1, p_x1])
+            pads = [0] * spatial * 2
+            input_shape = ctx.get_shape(node.input[0])
+            output_shape = ctx.get_shape(node.output[0])
+            if node.is_nhwc():
+                input_shape = spatial_map(input_shape, NHWC_TO_NCHW)
+                output_shape = spatial_map(output_shape, NHWC_TO_NCHW)
+            for i in range(spatial):
+                pad = (output_shape[i + 2] - 1) * strides[i] + dilations[i] * kernel_shape[i] - input_shape[i + 2]
+                pad = max(pad, 0)
+                pads[i] = pad // 2
+                pads[i + spatial] = pad - pad // 2
+            node.set_attr("pads", pads)
         elif padding == 'VALID':
             pass
         else:
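
The rewritten SAME branch derives the total pad from the actual input and output shapes, matching TensorFlow's definition (output = ceil(input / stride)); the old kernel-minus-stride rule could go wrong, e.g. kernel 2 with stride 2 on size 35 needs one pad pixel but got zero. A hand-computed instance of the new formula (not part of the commit):

    # one spatial axis: input 35, kernel 3, stride 2, dilation 1
    # TensorFlow SAME output: ceil(35 / 2) = 18
    pad = (18 - 1) * 2 + 1 * 3 - 35          # = 2
    begin, end = pad // 2, pad - pad // 2    # pads = [1, 1]
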
@@ -445,12 +453,11 @@ def conv_dims_attr(node, name, new_name=None):
     return dims


-def conv_kernel_shape(ctx, node, input_idx):
+def conv_kernel_shape(ctx, node, input_idx, spatial=2):
     kernel_shape = ctx.get_shape(node.input[1])
-    if len(kernel_shape) != 4:
-        raise ValueError("only Conv2D is supported")
-    h, w, c, n = kernel_shape
-    kernel_shape = [h, w]
+    if len(kernel_shape) != 2 * spatial:
+        raise ValueError("kernel rank must be 2 * spatial")
+    kernel_shape = kernel_shape[0:spatial]
     node.set_attr("kernel_shape", kernel_shape)
     return kernel_shape

@@ -460,11 +467,10 @@ def conv_op(ctx, node, name, args):
     # @string padding, @string data_format)
     # T Y = Conv(T X, T W, T B, @AttrType.STRING auto_pad, @AttrType.INTS dilations, @AttrType.INT group,
     #            @AttrType.INTS kernel_shape, @AttrType.INTS pads, @AttrType.INTS strides)
-    kernel_shape = conv_kernel_shape(ctx, node, 1)
+    kernel_shape = conv_kernel_shape(ctx, node, 1, spatial=2)
     strides = conv_dims_attr(node, "strides")
-    conv_dims_attr(node, "dilations")
-    add_padding(node, kernel_shape, strides)
-
+    dilations = conv_dims_attr(node, "dilations")
+    add_padding(ctx, node, kernel_shape, strides, dilations=dilations, spatial=2)
     nodes = conv_convert_inputs(ctx, node, with_kernel=True)
     return nodes
@@ -486,7 +492,7 @@ def convtranspose_op(ctx, node, name, args):
486492

487493
strides = conv_dims_attr(node, "strides")
488494
conv_dims_attr(node, "dilations")
489-
add_padding(node, kernel_shape, strides)
495+
add_padding(ctx, node, kernel_shape, strides)
490496

491497
# remove output_shapes input, swap data and kernel
492498
ctx.remove_input(node, node.input[0])
@@ -530,7 +536,7 @@ def depthwiseconv_op(ctx, node, name, args):
530536
strides = conv_dims_attr(node, "strides")
531537
conv_dims_attr(node, "dilations")
532538
node.set_attr("group", i_c)
533-
add_padding(node, kernel_shape, strides)
539+
add_padding(ctx, node, kernel_shape, strides)
534540

535541
new_kernel_shape = [k_output_channels, 1, k_h, k_w]
536542
nodes = conv_convert_inputs(ctx, node, with_kernel=True, new_kernel_shape=new_kernel_shape)
@@ -561,7 +567,7 @@ def pool_op(ctx, node, name, args):
561567

562568
conv_dims_attr(node, "dilations")
563569

564-
add_padding(node, kernel_shape, strides)
570+
add_padding(ctx, node, kernel_shape, strides)
565571

566572
nodes = conv_convert_inputs(ctx, node, with_kernel=False)
567573
return nodes
