Skip to content

Commit 48f9c5e

Browse files
committed
support for more complex strided_slice options
1 parent de9a299 commit 48f9c5e

File tree

3 files changed

+80
-23
lines changed

3 files changed

+80
-23
lines changed

tests/test_backend.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,6 @@ def get_conv_getdata(kind=1):
6969
data = [
7070
('SAME', [32, 35, 35, 288], [1, 3, 3, 1], [1, 2, 2, 1]),
7171
('SAME', [32, 35, 35, 288], [1, 2, 2, 1], [1, 2, 2, 1]),
72-
('SAME', [32, 35, 35, 288], [1, 2, 2, 1], [1, 1, 1, 1]),
73-
('SAME', [32, 35, 35, 288], [1, 5, 5, 1], [1, 1, 1, 1]),
74-
('SAME', [32, 35, 35, 288], [1, 1, 1, 1], [1, 2, 2, 1]),
7572
('SAME', [32, 35, 35, 288], [1, 1, 1, 1], [1, 1, 1, 1]),
7673
('SAME', [32, 35, 35, 288], [1, 5, 2, 1], [1, 2, 2, 1]),
7774
('SAME', [32, 35, 35, 288], [1, 2, 5, 1], [1, 2, 2, 1]),
@@ -83,7 +80,6 @@ def get_conv_getdata(kind=1):
8380
('SAME', [1, 28, 28, 3], [1, 5, 5, 1], [1, 2, 2, 1]),
8481
('SAME', [1, 28, 28, 3], [1, 5, 5, 1], [1, 1, 1, 1]),
8582
('SAME', [1, 28, 28, 3], [1, 5, 2, 1], [1, 2, 2, 1]),
86-
('SAME', [1, 28, 28, 3], [1, 2, 5, 1], [1, 1, 1, 1]),
8783
('SAME', [32, 8, 8, 2048], [1, 3, 3, 1], [1, 2, 2, 1]),
8884
('SAME', [32, 8, 8, 2048], [1, 3, 3, 1], [1, 1, 1, 1]),
8985
('VALID', [32, 35, 35, 288], [1, 3, 3, 1], [1, 1, 1, 1]),
@@ -975,7 +971,6 @@ def test_strided_slice1(self):
975971
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
976972
self.assertAllClose(expected, actual)
977973

978-
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "multiple dims not supported")
979974
def test_strided_slice2(self):
980975
x_val = np.arange(3*2*3).astype("float32").reshape(3, 2, 3)
981976
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
@@ -984,7 +979,6 @@ def test_strided_slice2(self):
984979
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
985980
self.assertAllClose(expected, actual)
986981

987-
@unittest.skip
988982
def test_strided_slice3(self):
989983
x_val = np.arange(3*2*3).astype("float32").reshape(3, 2, 3)
990984
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
@@ -993,7 +987,6 @@ def test_strided_slice3(self):
993987
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
994988
self.assertAllClose(expected, actual)
995989

996-
@unittest.skip
997990
def test_strided_slice4(self):
998991
x_val = np.arange(3*2*3).astype("float32").reshape(3, 2, 3)
999992
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
@@ -1002,7 +995,7 @@ def test_strided_slice4(self):
1002995
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
1003996
self.assertAllClose(expected, actual)
1004997

1005-
@unittest.skip
998+
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "multiple dims not supported")
1006999
def test_strided_slice5(self):
10071000
x_val = np.arange(3*2*3).astype("float32").reshape(3, 2, 3)
10081001
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
@@ -1011,6 +1004,17 @@ def test_strided_slice5(self):
10111004
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
10121005
self.assertAllClose(expected, actual)
10131006

1007+
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "multiple dims not supported")
1008+
def test_strided_slice6(self):
1009+
# example from here:
1010+
# https://www.tensorflow.org/versions/r1.0/api_docs/cc/class/tensorflow/ops/strided-slice
1011+
x_val = np.arange(5*6).astype("float32").reshape(5, 6)
1012+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1013+
x_ = x[2, :]
1014+
output = tf.identity(x_, name=_TFOUTPUT)
1015+
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
1016+
self.assertAllClose(expected, actual)
1017+
10141018
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "fails with schema error")
10151019
def test_batchnorm(self):
10161020
x_shape = [1, 28, 28, 2]
@@ -1075,7 +1079,7 @@ def test_fill(self):
10751079

10761080
if __name__ == "__main__":
10771081
parser = argparse.ArgumentParser()
1078-
parser.add_argument('--backend', default='caffe2',
1082+
parser.add_argument('--backend', default=BACKEND,
10791083
choices=["caffe2", "onnxmsrt", "onnxmsrtnext", "onnx-tensorflow"],
10801084
help="backend to test against")
10811085
parser.add_argument('--opset', default=OPSET,

tf2onnx/graph.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(self, node, graph):
3131
for a in node.attribute:
3232
self._attr[a.name] = a
3333
# try to find a dtype for this node
34-
dtype = graph.dtypes.get(node.name)
34+
dtype = graph._dtypes.get(node.name)
3535
if not dtype:
3636
dtype = self._attr.get("dtype")
3737
if dtype:
@@ -105,9 +105,10 @@ def __str__(self):
105105
def __repr__(self):
106106
return "<onnx op type='%s' name=%s>" % (self.type, self._op.name)
107107

108-
def get_attr(self, name):
108+
def get_attr(self, name, default=None):
109109
"""Get attribute map."""
110-
return self.attr.get(name)
110+
attr = self.attr.get(name, default)
111+
return attr
111112

112113
def set_attr(self, name, value):
113114
self.attr[name] = helper.make_attribute(name, value)
@@ -236,7 +237,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
236237
self.shapes = {}
237238
self.model_inputs = []
238239
self._target = set(target)
239-
self.dtypes = dtypes
240+
self._dtypes = dtypes
240241
self._output_shapes = output_shapes
241242
ops = [Node(node, self) for node in nodes]
242243
self.set_nodes(ops)
@@ -290,6 +291,10 @@ def add_initializer(self, tensor):
290291
"""Add tensor to initializers."""
291292
self._initializers[tensor.name] = tensor
292293

294+
def get_dtype(self, name):
295+
"""Get dtype for node."""
296+
return self._dtypes.get(name)
297+
293298
def get_shape(self, name):
294299
"""Get shape for node."""
295300
assert isinstance(name, str)

tf2onnx/tfonnx.py

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -733,24 +733,72 @@ def expanddims_op7(ctx, node, name, args):
733733

734734

735735
def stridedslice_op(ctx, node, name, args):
736-
# only the cases strides=1 can be mapped to onnx
737-
not_supported_attr = ["begin_mask", "ellipsis_mask", "end_mask", "new_axis_mask", "shrink_axis_mask"]
736+
# for now we implement common cases. Things like strides!=1 are not mappable to onnx.
737+
not_supported_attr = ["ellipsis_mask", "new_axis_mask"]
738738
for attr_name in not_supported_attr:
739739
attr = node.get_attr(attr_name)
740740
if attr is not None and attr.i != 0:
741-
raise ValueError("StridedSlice: attribute " + attr_name + " must be 0")
741+
raise ValueError("StridedSlice: attribute " + attr_name + " not supported")
742+
742743
begin = node.inputs[1].get_tensor_value()
743744
end = node.inputs[2].get_tensor_value()
744-
strides = node.inputs[3].get_tensor_value()[0]
745-
if strides != 1:
746-
raise ValueError("StridedSlice: only strides=1 is supported")
747-
node.set_attr("starts", list(begin))
748-
node.set_attr("ends", list(end))
745+
strides = node.inputs[3].get_tensor_value()
746+
end_mask = node.get_attr("end_mask")
747+
end_mask = end_mask.i if end_mask is not None else 0
748+
shrink_axis_mask = node.get_attr("shrink_axis_mask")
749+
shrink_axis_mask = shrink_axis_mask.i if shrink_axis_mask is not None else 0
750+
new_begin = []
751+
new_end = []
752+
axes = []
753+
# onnx slice op can't remove a axis, track axis and add a squeeze op if needed
754+
needs_squeeze = []
755+
for idx in range(len(begin)):
756+
if strides[idx] != 1:
757+
raise ValueError("StridedSlice: only strides=1 is supported")
758+
axes.append(idx)
759+
mask = (shrink_axis_mask >> idx) & 1
760+
if mask != 0:
761+
new_begin.append(begin[idx])
762+
new_end.append(end[idx])
763+
needs_squeeze.append(idx)
764+
continue
765+
766+
new_begin.append(begin[idx])
767+
mask = (end_mask >> idx) & 1
768+
if mask != 0:
769+
new_end.append(sys.maxsize)
770+
else:
771+
new_end.append(end[idx])
772+
773+
node.set_attr("starts", new_begin)
774+
node.set_attr("ends", new_end)
775+
node.set_attr("axes", axes)
749776
node.type = "Slice"
750777
ctx.remove_input(node, node.input[3])
751778
ctx.remove_input(node, node.input[2])
752779
ctx.remove_input(node, node.input[1])
753-
return node
780+
nodes = [node]
781+
if needs_squeeze:
782+
name = utils.make_name(node.name)
783+
squeeze_op = ctx.insert_new_node_on_output("Squeeze", node.output[0], name)
784+
squeeze_op.set_attr("axes", needs_squeeze)
785+
nodes.append(squeeze_op)
786+
ctx.copy_shape(node.output[0], squeeze_op.output[0])
787+
788+
# onnx slice as of opset 7 does only take float tensors ... cast if needed
789+
input_dtype = ctx.get_dtype(node.input[0])
790+
if input_dtype in [onnx_pb.TensorProto.INT32, onnx_pb.TensorProto.INT64]:
791+
cast_op = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
792+
cast_op.set_attr("to", onnx_pb.TensorProto.FLOAT)
793+
ctx.copy_shape(node.input[0], cast_op.output[0])
794+
nodes.insert(0, cast_op)
795+
name = utils.make_name(node.name)
796+
cast_op = ctx.insert_new_node_on_output("Cast", nodes[-1].output[0], name)
797+
cast_op.set_attr("to", input_dtype)
798+
ctx.copy_shape(node.input[0], cast_op.output[0])
799+
nodes.append(cast_op)
800+
801+
return nodes
754802

755803

756804
def pow_op(ctx, node, name, args):
@@ -1267,7 +1315,7 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
12671315
def tf_optimize(sess, inputs, outputs, graph_def):
12681316
"""Optimize tensorflow graph for inference."""
12691317
transforms = [
1270-
"fold_constants(ignore_errors=true)",
1318+
#"fold_constants(ignore_errors=true)",
12711319
"fold_batch_norms",
12721320
"fold_old_batch_norms",
12731321
]

0 commit comments

Comments
 (0)