Skip to content

Commit 563b0f7

Browse files
authored
Merge pull request #464 from zhijxu-MS/push_branch
support slice-10
2 parents 6ae0320 + 1572b6b commit 563b0f7

File tree

13 files changed

+313
-102
lines changed

13 files changed

+313
-102
lines changed

tests/test_backend.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,42 @@ def test_slice(self):
945945
_ = tf.identity(x_, name=_TFOUTPUT)
946946
self._run_test_case([_OUTPUT], {_INPUT: x_val})
947947

948+
@check_opset_min_version(10, "Slice in opset 10 can accept dymaic 'start' and 'ends'")
949+
def test_slice_with_non_const(self):
950+
x_val = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)
951+
t1 = np.array([0, 1], dtype=np.int32)
952+
t2 = np.array([2, 2], dtype=np.int32)
953+
x0 = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
954+
t1_ = tf.placeholder(tf.int32, t1.shape, name=_TFINPUT1)
955+
t2_ = tf.placeholder(tf.int32, t2.shape, name=_TFINPUT2)
956+
x_ = tf.slice(x0, t1_, t2_)
957+
_ = tf.identity(x_, name=_TFOUTPUT)
958+
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: t1, _INPUT2: t2})
959+
960+
@check_opset_min_version(10, "Slice in opset 10 can accept dymaic 'start' and 'ends'")
961+
def test_slice_with_size_is_negative_one(self):
962+
x_val = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)
963+
t1 = np.array([0, 1], dtype=np.int32)
964+
# input "size" contains -1
965+
t2 = np.array([2, -1], dtype=np.int32)
966+
x0 = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
967+
t1_ = tf.placeholder(tf.int32, t1.shape, name=_TFINPUT1)
968+
t2_ = tf.placeholder(tf.int32, t2.shape, name=_TFINPUT2)
969+
x_ = tf.slice(x0, t1_, t2_)
970+
_ = tf.identity(x_, name=_TFOUTPUT)
971+
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: t1, _INPUT2: t2})
972+
973+
@skip_caffe2_backend()
974+
def test_slice1(self):
975+
# FIXME: only 1 dimension supported by caffe2 and msrt
976+
x_val = np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]], dtype=np.float32)
977+
t1 = tf.constant([1, 0, 0], dtype=tf.int32)
978+
t2 = tf.constant([1, 1, 3], dtype=tf.int32)
979+
x0 = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
980+
x_ = tf.slice(x0, t1, t2)
981+
_ = tf.identity(x_, name=_TFOUTPUT)
982+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
983+
948984
def test_split(self):
949985
x_val = np.linspace(1.0, 5 * 30.0, 5 * 30).astype(np.float32).reshape(5, 30)
950986
x0 = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
@@ -1103,17 +1139,6 @@ def test_reducemean(self):
11031139
_ = tf.identity(x_, name=_TFOUTPUT)
11041140
self._run_test_case([_OUTPUT], {_INPUT: x_val})
11051141

1106-
@unittest.skip("")
1107-
def test_slice1(self):
1108-
# FIXME: only 1 dimension supported by caffe2 and msrt
1109-
x_val = np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]], dtype=np.float32)
1110-
t1 = tf.constant([1, 0, 0], dtype=tf.int32)
1111-
t2 = tf.constant([1, 1, 3], dtype=tf.int32)
1112-
x0 = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
1113-
x_ = tf.slice(x0, t1, t2)
1114-
_ = tf.identity(x_, name=_TFOUTPUT)
1115-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
1116-
11171142
@skip_caffe2_backend()
11181143
@check_onnxruntime_incompatibility("Pow")
11191144
def test_pow_scalar(self):
@@ -2078,5 +2103,6 @@ def test_space_to_batchnd(self):
20782103
_ = tf.space_to_batch_nd(input_x, block_size, pad, name=_TFOUTPUT)
20792104
self._run_test_case([_OUTPUT], {_INPUT: input_val})
20802105

2106+
20812107
if __name__ == '__main__':
20822108
unittest_main()

tf2onnx/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from __future__ import print_function
77
from __future__ import unicode_literals
88

9-
__all__ = ["utils", "graph_matcher", "graph", "loader", "tfonnx", "shape_inference", "schemas"]
9+
__all__ = ["utils", "graph_matcher", "graph", "graph_builder", "loader", "tfonnx", "shape_inference", "schemas"]
1010

1111
from .version import version as __version__
1212
from . import verbose_logging as logging
13-
from tf2onnx import tfonnx, utils, graph, graph_matcher, shape_inference, schemas # pylint: disable=wrong-import-order
13+
from tf2onnx import tfonnx, utils, graph, graph_builder, graph_matcher, shape_inference, schemas # pylint: disable=wrong-import-order

tf2onnx/graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from tf2onnx.utils import port_name, find_opset
2121
from tf2onnx import optimizer
2222
from tf2onnx.schemas import get_schema, infer_onnx_shape_dtype
23+
from tf2onnx import constants
2324

2425
logger = logging.getLogger(__name__)
2526

@@ -419,7 +420,8 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
419420
return node
420421

421422
def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, skip_conversion=True,
422-
op_name_scope=None, name=None, shapes=None, dtypes=None, domain=None, infer_shape_dtype=True):
423+
op_name_scope=None, name=None, shapes=None, dtypes=None, domain=constants.ONNX_DOMAIN,
424+
infer_shape_dtype=True):
423425
"""Make a new onnx node in the graph"""
424426
if attr is None:
425427
attr = {}

tf2onnx/graph_builder.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.graph_helper - class to help building graph, such as helping to make complex node
6+
"""
7+
8+
import numpy as np
9+
from tf2onnx import utils, logging
10+
11+
12+
# pylint: disable=missing-docstring
13+
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class GraphBuilder(object):
19+
"""help to build graph"""
20+
def __init__(self, graph):
21+
self._g = graph
22+
23+
@property
24+
def graph(self):
25+
return self._g
26+
27+
def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
28+
"""
29+
slice changes its schema at opset 10: it treats some attributes as dynamic input
30+
so this function has to process inputs according to graph's opset version
31+
to get "inputs" and "attr" to feed "make_node"
32+
kwargs: key could be ["data", "starts", "ends", "axes", "steps", "outputs"].
33+
"""
34+
outputs = kwargs.pop("outputs", None)
35+
36+
if self.graph.opset < 10:
37+
# "data" is string
38+
# "starts", "ends" and "axes" are attributes, and "axes" is optional.
39+
inputs = [kwargs.pop("data")]
40+
starts = self.convert_to_attribute(kwargs.pop("starts"))
41+
ends = self.convert_to_attribute(kwargs.pop("ends"))
42+
axes = self.convert_to_attribute(kwargs.pop("axes", None), is_optional=True)
43+
attr = {"starts": starts, "ends": ends, "axes": axes}
44+
else:
45+
# slice-10 has 3 required inputs "data", "starts", "ends"l
46+
# and 2 optional inputs "axes", "steps"
47+
# input sequence should be "data", "starts", "ends", "axes", "steps"
48+
attr = {}
49+
data = self.convert_to_input(kwargs.pop("data"))
50+
starts = self.convert_to_input(kwargs.pop("starts"))
51+
ends = self.convert_to_input(kwargs.pop("ends"))
52+
axes = self.convert_to_input(kwargs.pop("axes", None), is_optional=True)
53+
steps = self.convert_to_input(kwargs.pop("steps", None), is_optional=True)
54+
inputs = [data, starts, ends, axes, steps]
55+
56+
# pro-process inputs and attr
57+
if kwargs:
58+
logger.warning("kwargs contains un-used key")
59+
60+
new_attr = {}
61+
for key, val in attr.items():
62+
if val is not None:
63+
new_attr[key] = val
64+
attr = new_attr
65+
66+
for ind, val in enumerate(inputs):
67+
if val is None:
68+
inputs[ind] = "" # empty string means no connection in ONNX
69+
# remove tailing ""
70+
while inputs[-1] == "":
71+
inputs = inputs[:-1]
72+
73+
return self.graph.make_node(op_type="Slice", inputs=inputs, attr=attr, name=name,
74+
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
75+
76+
def convert_to_input(self, tensor, is_optional=False):
77+
"""in ONNX, input shold come from node, so it must be a string"""
78+
if is_optional and tensor is None:
79+
return None
80+
81+
utils.make_sure(tensor is not None, "input is required so it couldn't be None")
82+
83+
res = tensor
84+
if isinstance(tensor, list):
85+
res = self.graph.make_const(utils.make_name("const_slice"), np.array(tensor)).output[0]
86+
87+
utils.make_sure(isinstance(res, str), "input is a dynamic input, so a str is needed")
88+
89+
return res
90+
91+
def convert_to_attribute(self, tensor, is_optional=False):
92+
if is_optional and tensor is None:
93+
return None
94+
95+
utils.make_sure(tensor is not None, "input is required so it couldn't be None")
96+
97+
res = tensor
98+
if isinstance(tensor, str):
99+
const_node = self.graph.get_node_by_output(tensor)
100+
res = const_node.get_tensor_value(as_list=True)
101+
102+
utils.make_sure(isinstance(res, list), "input is an attr, so a list is needed")
103+
104+
return res

tf2onnx/onnx_opset/nn.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from onnx import onnx_pb
1616
from onnx.onnx_pb import TensorProto
1717
from tf2onnx import constants, utils
18+
from tf2onnx.graph_builder import GraphBuilder
1819
from tf2onnx.handler import tf_op
1920
from tf2onnx.onnx_opset import common, controlflow, tensor
2021

@@ -463,6 +464,14 @@ def version_7(cls, ctx, node, **kwargs):
463464

464465
@classmethod
465466
def version_9(cls, ctx, node, **kwargs):
467+
cls._convert_since_9(ctx, node, **kwargs)
468+
469+
@classmethod
470+
def version_10(cls, ctx, node, **kwargs):
471+
cls._convert_since_9(ctx, node, **kwargs)
472+
473+
@classmethod
474+
def _convert_since_9(cls, ctx, node, **kwargs):
466475
# float32 out = ResizeBilinear/ResizeNearestNeighbor(T images, int size)
467476
# https://www.tensorflow.org/api_docs/python/tf/image/resize_nearest_neighbor
468477
# wants the input to be NHWC - adjust target_shape to this.
@@ -481,8 +490,10 @@ def version_9(cls, ctx, node, **kwargs):
481490
scales = ctx.make_const(utils.make_name("scales"), scale_val, raw=False)
482491
else:
483492
ori_shape = ctx.make_node("Shape", [node.input[0]])
484-
ori_shape_hw = ctx.make_node("Slice", ori_shape.output, {"axes": [0], "starts": [1], "ends": [3]})
485-
ori_shape_hw_float = ctx.make_node("Cast", ori_shape_hw.output, attr={"to": onnx_pb.TensorProto.FLOAT})
493+
attr = {"axes": [0], "starts": [1], "ends": [3]}
494+
inputs_map = {"data": ori_shape.output[0], **attr}
495+
ori_shape_hw = GraphBuilder(ctx).make_slice(inputs_map)
496+
ori_shape_hw_float = ctx.make_node("Cast", [ori_shape_hw], attr={"to": onnx_pb.TensorProto.FLOAT})
486497

487498
target_hw = node.inputs[1]
488499
target_hw_float = ctx.make_node("Cast", target_hw.output, attr={"to": onnx_pb.TensorProto.FLOAT})
@@ -538,12 +549,13 @@ def version_7(cls, ctx, node, **kwargs):
538549
new_line = g.make_node(op_type="Concat", inputs=[const_zero_bool.output[0], "line"],
539550
attr={"axis": counter_axis},
540551
dtypes=[onnx_pb.TensorProto.BOOL])
541-
slice_node = g.make_node(op_type="Slice", inputs=[new_line.output[0]],
542-
attr={"axes": [counter_axis], "starts": [0], "ends": [-1]})
552+
attr = {"axes": [counter_axis], "starts": [0], "ends": [-1]}
553+
inputs_map = {"data": new_line.output[0], **attr}
554+
slice_node = GraphBuilder(g).make_slice(inputs_map)
543555

544556
g.make_node("Identity", ["cond"], outputs=["cond_out"])
545557
g.make_node("Identity", ["line"], outputs=["res"])
546-
g.make_node("Identity", [slice_node.output[0]], outputs=["line_out"])
558+
g.make_node("Identity", [slice_node], outputs=["line_out"])
547559

548560
g.add_graph_input("trip", onnx_pb.TensorProto.INT64, [])
549561
g.add_graph_input("cond", onnx_pb.TensorProto.BOOL, [])

0 commit comments

Comments
 (0)