Skip to content

Commit 90f3689

Browse files
Implement symmetric padding mode (#1698)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 186b954 commit 90f3689

File tree

2 files changed

+98
-8
lines changed

2 files changed

+98
-8
lines changed

tests/test_backend.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2044,6 +2044,24 @@ def func(x):
20442044
return tf.identity(op, name=_TFOUTPUT)
20452045
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
20462046

2047+
@check_opset_min_version(9, "Compress")
2048+
def test_pad_symmetric(self):
2049+
x_val = make_xval([4, 1, 5])
2050+
def func(x):
2051+
paddings = tf.constant([[1, 3], [0, 0], [2, 4]], name="paddings")
2052+
op = tf.pad(x, paddings, mode="SYMMETRIC", name="symmetric")
2053+
return tf.identity(op, name=_TFOUTPUT)
2054+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
2055+
2056+
@check_opset_min_version(11, "Pad")
2057+
def test_dynamic_pad_symmetric(self):
2058+
x_val = make_xval([4, 1, 5])
2059+
y_val = np.array([[1, 3], [0, 0], [2, 4]], np.int32)
2060+
def func(x, y):
2061+
op = tf.pad(x, y, mode="SYMMETRIC", name="symmetric")
2062+
return tf.identity(op, name=_TFOUTPUT)
2063+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
2064+
20472065
@skip_caffe2_backend()
20482066
def test_randomuniform(self):
20492067
def func():

tf2onnx/onnx_opset/nn.py

Lines changed: 80 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,67 @@ def version_7(cls, ctx, node, **kwargs):
800800

801801
@tf_op(["Pad", "PadV2", "MirrorPad"], onnx_op="Pad")
802802
class Pad:
803+
804+
@classmethod
805+
def convert_symmetric_pads(cls, ctx, node):
806+
"""Currently there isn't a symmetric padding mode in ONNX so we add a dummy row then use the reflect mode
807+
and remove the dummy row with compress. Ex: 1234 -> 012340 -> 2101234043 -> 21123443. Only do this to
808+
dims with non-zero pads (if pads are constant)"""
809+
rank = ctx.get_rank(node.input[0])
810+
utils.make_sure(rank is not None, "Cannot convert pad with symmetric mode and unknown rank")
811+
utils.make_sure(ctx.opset >= 9, "opset 9 required for symmetric padding mode")
812+
node.set_attr("mode", "reflect")
813+
const_pads = None
814+
consumers = ctx.find_output_consumers(node.output[0])
815+
output_shape = ctx.get_shape(node.output[0])
816+
if ctx.opset < 11:
817+
const_pads = node.get_attr_value("pads")
818+
elif node.inputs[1].is_const():
819+
const_pads = node.inputs[1].get_tensor_value()
820+
non_zero_axes = list(range(rank))
821+
if const_pads is not None:
822+
non_zero_axes = []
823+
for i in range(rank):
824+
if const_pads[i] != 0 or const_pads[i + rank] != 0:
825+
non_zero_axes.append(i)
826+
827+
inc_pads = [0] * (rank * 2)
828+
for a in non_zero_axes:
829+
inc_pads[a] = 1
830+
inc_pads[a + rank] = 1
831+
832+
if ctx.opset < 11:
833+
padded_inp = ctx.make_node("Pad", [node.input[0]], attr={'mode': 'constant', 'pads': inc_pads}).output[0]
834+
else:
835+
pad1_pads_const = ctx.make_const(utils.make_name("pad1_pads"), np.array(inc_pads, np.int64)).output[0]
836+
padded_inp = ctx.make_node("Pad", [node.input[0], pad1_pads_const], attr={'mode': 'constant'}).output[0]
837+
ctx.replace_input(node, node.input[0], padded_inp, 0)
838+
ctx.update_node_shape_dtype(node, override=True)
839+
840+
output = node.output[0]
841+
shape = ctx.make_node("Shape", [output]).output[0]
842+
dims = ctx.make_node("Split", [shape], output_count=rank).output
843+
two_false = ctx.make_const(utils.make_name("two_false"), np.array([False, False], np.bool)).output[0]
844+
inv_second = ctx.make_const(utils.make_name("inv_second"), np.array([1, -1], np.int64)).output[0]
845+
dec_second = ctx.make_const(utils.make_name("dec_second"), np.array([0, 1], np.int64)).output[0]
846+
for a in non_zero_axes:
847+
one_tensor = helper.make_tensor("value", onnx_pb.TensorProto.BOOL, dims=[1], vals=[1])
848+
ones_of_shape = ctx.make_node("ConstantOfShape", [dims[a]], attr={'value': one_tensor}).output[0]
849+
if const_pads is not None:
850+
to_remove_val = [const_pads[a], -1 - const_pads[a + rank]]
851+
to_remove = ctx.make_const(utils.make_name("to_remove"), np.array(to_remove_val, np.int64)).output[0]
852+
else:
853+
pads_idx = ctx.make_const(utils.make_name("pads_idx"), np.array([a, a + rank], np.int64)).output[0]
854+
pads_vals = ctx.make_node("Gather", [node.input[1], pads_idx]).output[0]
855+
pads_inv_second = ctx.make_node("Mul", [pads_vals, inv_second]).output[0]
856+
to_remove = ctx.make_node("Sub", [pads_inv_second, dec_second]).output[0]
857+
scatter_op = "ScatterElements" if ctx.opset >= 11 else "Scatter"
858+
dims_to_keep = ctx.make_node(scatter_op, [ones_of_shape, to_remove, two_false]).output[0]
859+
compress = ctx.make_node("Compress", [output, dims_to_keep], attr={'axis': a})
860+
output = compress.output[0]
861+
ctx.replace_all_inputs(node.output[0], output, consumers)
862+
ctx.set_shape(output, output_shape)
863+
803864
@classmethod
804865
def version_1(cls, ctx, node, **kwargs):
805866
node.type = "Pad"
@@ -812,7 +873,7 @@ def version_1(cls, ctx, node, **kwargs):
812873
if mode:
813874
mode = mode.s.decode("utf-8").lower()
814875
node.set_attr("mode", mode)
815-
if mode not in [None, "constant", "reflect"]:
876+
if mode not in [None, "symmetric", "constant", "reflect"]:
816877
raise ValueError(mode + " pad mode is not supported")
817878

818879
if mode in [None, "constant"] and len(node.input) == 3:
@@ -836,21 +897,29 @@ def version_1(cls, ctx, node, **kwargs):
836897
ctx.set_dtype(cast_back_node.output[0], origin_dtype)
837898
ctx.copy_shape(node.name, cast_back_node.output[0])
838899

900+
if mode == "symmetric":
901+
cls.convert_symmetric_pads(ctx, node)
902+
839903
@classmethod
840904
def version_11(cls, ctx, node, **kwargs):
841905
mode = node.get_attr("mode")
842906
if mode:
843907
mode = mode.s.decode("utf-8").lower()
844908
node.set_attr("mode", mode)
845-
if mode not in [None, "constant", "reflect"]:
909+
if mode not in [None, "symmetric", "constant", "reflect"]:
846910
raise ValueError(mode + " pad mode is not supported")
847911

848-
# pads must be int64.
849-
if ctx.get_dtype(node.input[1]) != onnx_pb.TensorProto.INT64:
850-
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
851-
ctx.insert_new_node_on_input(node, "Transpose", node.input[1])
852-
shape_const = ctx.make_const(utils.make_name(node.name), np.array([-1]).astype(np.int64))
853-
ctx.insert_new_node_on_input(node, "Reshape", [node.input[1], shape_const.name])
912+
if not node.inputs[1].is_const():
913+
# pads must be int64.
914+
if ctx.get_dtype(node.input[1]) != onnx_pb.TensorProto.INT64:
915+
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
916+
ctx.insert_new_node_on_input(node, "Transpose", node.input[1])
917+
shape_const = ctx.make_const(utils.make_name(node.name), np.array([-1]).astype(np.int64))
918+
ctx.insert_new_node_on_input(node, "Reshape", [node.input[1], shape_const.name])
919+
else:
920+
paddings = node.inputs[1].get_tensor_value(as_list=False).astype(np.int64).transpose().flatten()
921+
pad_const = ctx.make_const(utils.make_name("pad_const"), paddings)
922+
ctx.replace_input(node, node.input[1], pad_const.output[0], 1)
854923

855924
origin_dtype = ctx.get_dtype(node.output[0])
856925
if origin_dtype not in [TensorProto.FLOAT, TensorProto.DOUBLE,
@@ -865,6 +934,9 @@ def version_11(cls, ctx, node, **kwargs):
865934
ctx.set_dtype(cast_back_node.output[0], origin_dtype)
866935
ctx.copy_shape(node.name, cast_back_node.output[0])
867936

937+
if mode == "symmetric":
938+
cls.convert_symmetric_pads(ctx, node)
939+
868940

869941
@tf_op(["FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3"])
870942
class BatchNorm:

0 commit comments

Comments
 (0)