Skip to content

Commit 0411ca0

Browse files
authored
Merge pull request #37 from onnx/gs/onnx-1.2
added support for topk, stridedslice(limited), custom opset
2 parents 767432b + 4102833 commit 0411ca0

File tree

6 files changed

+85
-26
lines changed

6 files changed

+85
-26
lines changed

tests/test_backend.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def run_onnxmsrt(onnx_graph, inputs, output_names, test_name):
8484
model_path = os.path.join(TMPPATH, test_name + ".pb")
8585
with open(model_path, "wb") as f:
8686
f.write(onnx_graph.SerializeToString())
87-
8887
m = lotus.ModelExecutor(model_path)
8988
results = m.run(output_names, inputs)
9089
return results[0]
@@ -96,7 +95,6 @@ def run_onnxmsrtnext(onnx_graph, inputs, output_names, test_name):
9695
model_path = os.path.join(TMPPATH, test_name + ".pb")
9796
with open(model_path, "wb") as f:
9897
f.write(onnx_graph.SerializeToString())
99-
10098
m = lotus.InferenceSession(model_path)
10199
results = m.run(output_names, inputs)
102100
return results[0]
@@ -764,19 +762,32 @@ def test_cancel_transpose(self):
764762
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
765763
self.assertAllClose(expected, actual)
766764

767-
@unittest.skip
768-
def test_strided_slice0(self):
769-
# FIXME: not implemented yet
770-
x_val = np.array([
771-
[[1, 1, 1], [2, 2, 2]],
772-
[[3, 3, 3], [4, 4, 4]],
773-
[[5, 5, 5], [6, 6, 6]]], dtype=np.float32)
765+
def test_topk(self):
766+
x_val = np.arange(3*2*3).astype("float32")
767+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
768+
values, indices = tf.nn.top_k(x, 5, sorted=True)
769+
output = tf.identity(values, name=_TFOUTPUT)
770+
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
771+
self.assertAllClose(expected, actual)
772+
773+
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "multiple dims not supported")
774+
def test_strided_slice1(self):
775+
x_val = np.arange(3*2*3).astype("float32").reshape(3, 2, 3)
774776
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
775777
x_ = tf.strided_slice(x, [1, 0, 0], [2, 1, 3], [1, 1, 1])
776778
output = tf.identity(x_, name=_TFOUTPUT)
777779
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
778780
self.assertAllClose(expected, actual)
779781

782+
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "multiple dims not supported")
783+
def test_strided_slice2(self):
784+
x_val = np.arange(3*2*3).astype("float32").reshape(3, 2, 3)
785+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
786+
x_ = tf.strided_slice(x, [1, 0, 0], [2, 2, 3], [1, 1, 1])
787+
output = tf.identity(x_, name=_TFOUTPUT)
788+
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
789+
self.assertAllClose(expected, actual)
790+
780791
@unittest.skip
781792
def test_resize_nearest_neighbor(self):
782793
# this should work but no runtime I tried supports it.

tests/test_graph.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,14 +266,17 @@ def print_handler(ctx, node, name, args):
266266
# becomes:
267267
# T output = Identity(T Input)
268268
node.type = "Identity"
269+
node.domain = "tf"
269270
del node.input[1:]
270271
return node
271272

272273
with tf.Session() as sess:
273274
x = tf.placeholder(tf.float32, [2, 3], name="input1")
274275
x_ = tf.Print(x, [x], "hello")
275276
_ = tf.identity(x_, name="output")
276-
g = process_tf_graph(sess.graph, custom_op_handlers={"Print": print_handler})
277+
g = process_tf_graph(sess.graph,
278+
custom_op_handlers={"Print": print_handler},
279+
extra_opset=helper.make_opsetid("tf", 1))
277280
self.assertEqual(
278281
'digraph { Print [op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
279282
onnx_to_graphviz(g))

tf2onnx/convert.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import tensorflow as tf
1313
import tf2onnx.utils
1414
from tf2onnx.tfonnx import process_tf_graph, tf_optimize, DEFAULT_TARGET, POSSIBLE_TARGETS
15+
from onnx import helper
1516

1617

1718
def get_args():
@@ -44,7 +45,7 @@ def get_args():
4445

4546

4647
def default_custom_op_handler(ctx, node, name, args):
47-
node.type = "tf." + node.type
48+
node.domain = "tf"
4849
return node
4950

5051

@@ -57,7 +58,13 @@ def main():
5758
# support unknown dimensions.
5859
tf2onnx.utils.ONNX_UNKNOWN_DIMENSION = args.unknown_dim
5960

60-
custom_ops = {op: default_custom_op_handler for op in args.custom_ops.split(",")} if args.custom_ops else {}
61+
if args.custom_ops:
62+
# default custom ops for tensorflow-onnx are in the "tf" namespace
63+
custom_ops = {op: default_custom_op_handler for op in args.custom_ops.split(",")}
64+
extra_opset = [helper.make_opsetid("tf", 1)]
65+
else:
66+
args.custom_ops = {}
67+
extra_opset = None
6168

6269
graph_def = tf.GraphDef()
6370
with tf.gfile.FastGFile(args.input, 'rb') as f:
@@ -71,7 +78,8 @@ def main():
7178
verbose=args.verbose,
7279
target=args.target,
7380
opset=args.opset,
74-
custom_op_handlers=custom_ops)
81+
custom_op_handlers=custom_ops,
82+
extra_opset=extra_opset)
7583

7684
model_proto = g.make_model(
7785
"converted from {}".format(args.input), args.inputs, args.outputs,

tf2onnx/graph.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@ def type(self, val):
8181
"""Set Op type."""
8282
self._op.op_type = val
8383

84+
@property
85+
def domain(self):
86+
"""Return Op type."""
87+
return self._op.domain
88+
89+
@type.setter
90+
def domain(self, val):
91+
"""Set Op type."""
92+
self._op.domain = val
93+
8494
def is_nhwc(self):
8595
"""Return True if node is in NCHW format."""
8696
return self.data_format == "NHWC"
@@ -200,7 +210,7 @@ def update_proto(self):
200210
class Graph(object):
201211
""""Class that provides graph manipulation and matching."""
202212

203-
def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=None):
213+
def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=None, extra_opset=None):
204214
"""Create Graph.
205215
Args:
206216
nodes: list of Node()
@@ -219,9 +229,10 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
219229
self._output_shapes = output_shapes
220230
ops = [Node(node, self) for node in nodes]
221231
self.set_nodes(ops)
222-
if opset is None:
232+
if opset is None or opset == 0:
223233
opset = defs.onnx_opset_version()
224234
self._opset = opset
235+
self._extra_opset = extra_opset
225236

226237
@property
227238
def opset(self):
@@ -401,10 +412,13 @@ def make_model(self, doc, input_names, output_names, optimize=True):
401412

402413
kwargs = {"producer_name": "tf2onnx",
403414
"producer_version": __version__}
404-
if self._opset > 0:
405-
imp = OperatorSetIdProto()
406-
imp.version = self._opset
407-
kwargs["opset_imports"] = [imp]
415+
opsets = []
416+
imp = OperatorSetIdProto()
417+
imp.version = self._opset
418+
opsets.append(imp)
419+
if self._extra_opset is not None:
420+
opsets.extend(self._extra_opset)
421+
kwargs["opset_imports"] = opsets
408422

409423
model_proto = helper.make_model(graph, **kwargs)
410424

tf2onnx/tfonnx.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -705,10 +705,24 @@ def expanddims_op7(ctx, node, name, args):
705705

706706

707707
def stridedslice_op(ctx, node, name, args):
708-
# T output = StridedSlice(T input, Index begin, Index end, Index strides,
709-
# @type Index, @int begin_mask, @int end_mask, @int ellipsis_mask,
710-
# @int new_axis_mask, @int shrink_axis_mask)
711-
raise ValueError("StridedSlice not implemented")
708+
# only the cases strides=1 can be mapped to onnx
709+
not_supported_attr = ["begin_mask", "ellipsis_mask", "end_mask", "new_axis_mask", "shrink_axis_mask"]
710+
for attr_name in not_supported_attr:
711+
attr = node.get_attr(attr_name)
712+
if attr is not None and attr.i != 0:
713+
raise ValueError("StridedSlice: attribute " + attr_name + " must be 0")
714+
begin = node.inputs[1].get_tensor_value()
715+
end = node.inputs[2].get_tensor_value()
716+
strides = node.inputs[3].get_tensor_value()[0]
717+
if strides != 1:
718+
raise ValueError("StridedSlice: only strides=1 is supported")
719+
node.set_attr("starts", list(begin))
720+
node.set_attr("ends", list(end))
721+
node.type = "Slice"
722+
ctx.remove_input(node, node.input[3])
723+
ctx.remove_input(node, node.input[2])
724+
ctx.remove_input(node, node.input[1])
725+
return node
712726

713727

714728
def pow_op(ctx, node, name, args):
@@ -765,6 +779,14 @@ def multinomial_op(ctx, node, name, args):
765779
return node
766780

767781

782+
def topk_op(ctx, node, name, args):
783+
k = node.inputs[1].get_tensor_value()
784+
node.set_attr("k", k[0])
785+
node.type = "TopK"
786+
ctx.remove_input(node, node.input[1])
787+
return node
788+
789+
768790
# pylint: enable=W0613,C0111,W0612
769791

770792
# map tensorflow ops to onnx ops. The format below is
@@ -847,6 +869,7 @@ def multinomial_op(ctx, node, name, args):
847869
"Sum": (reduce_op, ["ReduceSum"]),
848870
"Tanh": (direct_op, []),
849871
"Transpose": (transpose_op, []),
872+
"TopKV2": (topk_op, []),
850873
}
851874

852875
_OPSET_5 = {
@@ -1078,7 +1101,7 @@ def tf_optimize(sess, inputs, outputs, graph_def):
10781101

10791102

10801103
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
1081-
opset=None, custom_op_handlers=None, custom_rewriter=None):
1104+
opset=None, custom_op_handlers=None, custom_rewriter=None, extra_opset=None):
10821105
"""Convert tensorflow graph to onnx graph.
10831106
Args:
10841107
tf_graph: tensorflow graph
@@ -1106,7 +1129,7 @@ def topological_sort(ops):
11061129

11071130
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tensorflow_to_onnx(tf_graph)
11081131

1109-
g = Graph(onnx_nodes, output_shapes, dtypes, target, opset)
1132+
g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset)
11101133
ops = g.get_nodes()
11111134

11121135
# rewrite graph

tf2onnx/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
'min', 'seed', 'ends', 'paddings', 'to', 'gamma', 'width_scale', 'normalize_variance', 'group', 'ratio', 'values',
7676
'dtype', 'output_shape', 'spatial', 'split', 'input_forget', 'keepdims', 'transA', 'auto_pad', 'border', 'low',
7777
'linear_before_reset', 'height_scale', 'output_padding', 'shape', 'kernel_shape', 'epsilon', 'size', 'starts',
78-
'direction', 'max', 'clip', 'across_channels', 'value', 'strides', 'extra_shape', 'scales'
78+
'direction', 'max', 'clip', 'across_channels', 'value', 'strides', 'extra_shape', 'scales', 'k'
7979
}
8080

8181

0 commit comments

Comments
 (0)