
Commit cdabebf (parent 5810313)
Author: wayuanho

fuse Pad op before Conv2D op into Conv op as an attribute
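In ONNX, padding can be expressed directly on the Conv node through its pads attribute, so an explicit Pad feeding a VALID Conv2D is redundant. A minimal sketch of the attribute mapping this commit performs (NumPy only; the paddings values are taken from the new test case below):

import numpy as np

# TF paddings for an NHWC tensor: one [before, after] pair per dimension.
paddings = np.array([[0, 0], [2, 2], [2, 2], [0, 0]])

# ONNX Conv cannot pad the batch or channel dims, so those rows must be zero.
assert not np.any(paddings[0]) and not np.any(paddings[3])

# Keep the spatial rows (H, W) and reorder to the ONNX layout:
# pads = [H_begin, W_begin, H_end, W_end]
pads = paddings[1:3].transpose().flatten()
print(pads)  # -> [2 2 2 2]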

File tree (3 files changed, +68 -1 lines):
  tests/test_backend.py
  tf2onnx/graph.py
  tf2onnx/tfonnx.py


tests/test_backend.py
Lines changed: 12 additions & 0 deletions

@@ -288,6 +288,18 @@ def test_conv2d_8(self):
         _ = tf.identity(conv, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=1e-5)
 
+    def test_conv2d_with_pad(self):
+        x_val = make_xval((1, 1, 5, 5)).transpose(NCHW_TO_NHWC)
+        w = np.random.random_sample([3, 3, 1, 2]).astype(np.float32)
+        strides = [1, 1, 1, 1]
+
+        x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
+        kernel = tf.constant(w, dtype=tf.float32, name='k')
+        x_pad = tf.pad(x, paddings=[[0, 0], [2, 2], [2, 2], [0, 0]])
+        conv = tf.nn.conv2d(x_pad, kernel, strides=strides, padding="VALID")
+        _ = tf.identity(conv, name=_TFOUTPUT)
+        self._run_test_case([_OUTPUT], {_INPUT: x_val}, rtol=1e-5)
+
     def test_conv2d_transpose(self):
         x_shape = [2, 6, 4, 3]
         output_shape = [2, 13, 9, 2]
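A quick shape check on the new test (arithmetic only, not part of the commit): padding the 5x5 spatial input by 2 on each side gives 9x9, and a 3x3 VALID convolution over 9x9 yields 7x7, so the fused Conv with pads=[2, 2, 2, 2] must reproduce that shape.

# Output size of a VALID conv: out = in + pad_begin + pad_end - kernel + 1
out_h = 5 + 2 + 2 - 3 + 1   # -> 7, matching a 1x7x7x2 NHWC result
assert out_h == 7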

tf2onnx/graph.py
Lines changed: 8 additions & 0 deletions

@@ -150,6 +150,14 @@ def set_deleted(self):
     def is_deleted(self):
         return self.type == "@@DELETED@@"
 
+    @property
+    def skip_conversion(self):
+        return self._skip_conversion
+
+    @skip_conversion.setter
+    def skip_conversion(self, val):
+        self._skip_conversion = val
+
     # If some Node is created as onnx_node, then we don't need convert it
     def need_skip(self):
         return self._skip_conversion
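The new property simply exposes the existing _skip_conversion flag that need_skip() already reads, so rewriters can mark nodes they have converted themselves. A minimal illustration of the contract (this Node is a simplified stand-in for tf2onnx.graph.Node, not the real class):

class Node:
    def __init__(self):
        self._skip_conversion = False

    @property
    def skip_conversion(self):
        return self._skip_conversion

    @skip_conversion.setter
    def skip_conversion(self, val):
        self._skip_conversion = val

    def need_skip(self):
        return self._skip_conversion

n = Node()
n.skip_conversion = True   # a rewriter marks the node it already converted
assert n.need_skip()       # the op-mapping pass will then leave it alone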

tf2onnx/tfonnx.py
Lines changed: 48 additions & 1 deletion

@@ -383,6 +383,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
         transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
         transpose.set_attr("perm", NHWC_TO_NCHW)
         transpose.inserted_nchw = True
+        transpose.skip_conversion = True
         shape = ctx.get_shape(input_name)
         new_shape = spatial_map(shape, NHWC_TO_NCHW)
         ctx.set_shape(transpose.output[0], new_shape)
@@ -399,6 +400,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
         transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
         transpose.set_attr("perm", HWCN_TO_NCHW)
         transpose.inserted_nchw = True
+        transpose.skip_conversion = True
         ctx.copy_shape(input_name, transpose.output[0])
         new_shape = spatial_map(ctx.get_shape(input_name), HWCN_TO_NCHW)
         ctx.set_shape(transpose.output[0], new_shape)
@@ -412,13 +414,15 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
             input_name = node.input[1]
             reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
             reshape.set_attr("shape", new_kernel_shape)
+            reshape.skip_conversion = True
         else:
             # new reshape takes new shape as input[1]
             shape_name = utils.make_name(node.name)
             nodes.append(ctx.make_const(shape_name, np.array(new_kernel_shape, dtype=np.int64)))
             input_name = node.input[1]
             reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
             reshape.input.append(shape_name)
+            reshape.skip_conversion = True
         ctx.set_shape(reshape.output[0], new_kernel_shape)
         nodes.append(reshape)
 
@@ -433,6 +437,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
         transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)
         transpose.set_attr("perm", NCHW_TO_NHWC)
         transpose.inserted_nchw = True
+        transpose.skip_conversion = True
         ctx.set_shape(transpose.output[0], ctx.get_shape(node.output[idx]))
         nodes.append(transpose)
     node.data_format = "NCHW"
@@ -2280,6 +2285,47 @@ def rewrite_incomplete_type_support_rs6(g, ops):
     return rewrite_incomplete_type_support(g, ops, ["Div", "ReduceSum", "Slice", "Split", "Tile", "Transpose"])
 
 
+def rewrite_conv2d_with_pad(g, ops):
+    pattern = \
+        OpTypePattern("Conv2D", name="conv", inputs=[
+            OpTypePattern("Pad", name="pad"),
+            OpTypePattern("*")
+        ])
+    matcher = GraphMatcher(pattern)
+    match_results = list(matcher.match_ops(ops))
+    for match in match_results:
+        conv = match.get_op("conv")
+        pad = match.get_op("pad")
+        paddings = pad.inputs[1]
+
+        if not paddings.is_const():
+            return ops
+        mode = pad.get_attr("mode")
+        if mode:
+            mode = mode.s.decode("utf-8").lower()
+        if mode not in [None, "constant"] or len(pad.input) >= 3:
+            return ops
+        # Conv2D already has a pad
+        if conv.get_attr("padding") == "SAME":
+            return ops
+
+        log.debug("merge pad [%s] into conv [%s]", pad.name, conv.name)
+        paddings_val = np.array(paddings.get_tensor_value())
+        # can't pad on batch or channel dimensions
+        if np.any(paddings_val[0]) or np.any(paddings_val[3]):
+            return ops
+        paddings_val = paddings_val[1:3]
+        paddings_val = paddings_val.transpose().flatten()
+        g.replace_input(conv, conv.input[0], pad.input[0])
+        # convert Conv2D
+        conv.type = "Conv"
+        ops.extend(conv_op(g, conv, conv.name, []))
+        conv.skip_conversion = True
+        conv.set_attr("auto_pad", "NOTSET")
+        conv.set_attr("pads", paddings_val)
+    return ops
+
+
 def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
     mapped_op = collections.Counter()
     unmapped_op = collections.Counter()
@@ -2476,7 +2522,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
     # bi-directional re-writer should be placed after single directional re-writer
     rewriters = [rewrite_transpose, rewrite_flatten,
                  rewrite_random_uniform, rewrite_random_uniform_fold_const,
-                 rewrite_random_normal, rewrite_dropout, rewrite_leakyrelu,
+                 rewrite_random_normal, rewrite_dropout,
+                 rewrite_leakyrelu, rewrite_conv2d_with_pad,
                  rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
                  rewrite_single_direction_gru, rewrite_single_direction_grublock,
                  rewrite_bi_direction_gru, rewrite_logical_compare_with_equal,
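For reference, after the rewrite the test graph should contain a single ONNX Conv roughly equivalent to the node below (built here with onnx.helper purely as an illustration; the tensor names and explicit strides are assumptions, while auto_pad and pads come from the diff):

from onnx import helper

conv = helper.make_node(
    "Conv",
    inputs=["x", "k"],       # the Pad op's own input replaces the padded input
    outputs=["y"],
    auto_pad="NOTSET",
    pads=[2, 2, 2, 2],       # [H_begin, W_begin, H_end, W_end]
    strides=[1, 1],
)
print(conv)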
