Skip to content

Commit cd77071

Browse files
committed
add ms range_op1
1 parent d064ff0 commit cd77071

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

tf2onnx/custom/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
""" custom tf2onnx mapping functions. """
4+
5+
from . import ms

tf2onnx/custom/ms.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
""" tf2onnx mapping functions for ms domain. """
4+
5+
from onnx.onnx_pb import TensorProto
6+
from tf2onnx import constants, utils
7+
from tf2onnx.function.range import make_range_const
8+
9+
10+
# pylint: disable=unused-argument
11+
12+
def range_op1(ctx, node, name, args):
13+
"""Range."""
14+
# T range = Range(T start, T limit, T delta)
15+
dtype = node.get_attr_int("Tidx")
16+
shape = node.output_shapes[0]
17+
utils.make_sure(dtype is not None, "Tidx of %s is None", node.name)
18+
ctx.remove_node(node.name)
19+
make_range(ctx, node.input[0], node.input[1], node.input[2], node.output[0], name, shape, dtype)
20+
21+
22+
def make_range(ctx, start, limit, delta, output, scope_name, shape, dtype):
23+
if all(ctx.get_node_by_output(n).is_const() for n in [start, limit, delta]) is True:
24+
make_range_const(ctx, start, limit, delta, output, scope_name, shape, dtype)
25+
else:
26+
_make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, dtype)
27+
28+
29+
def _make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, dtype):
30+
utils.make_sure(
31+
dtype in [TensorProto.FLOAT, TensorProto.DOUBLE, TensorProto.INT16, TensorProto.INT32, TensorProto.INT64],
32+
"dtype %s is not supported", dtype)
33+
ctx.make_node("Range", [start, limit, delta], outputs=[output], name=scope_name, shapes=[shape], dtypes=[dtype],
34+
domain=constants.MICROSOFT_DOMAIN)
35+
36+
37+
_OPSET_1 = {
38+
"Range": (range_op1, []),
39+
}
40+
41+
OPSETS = [
42+
(1, _OPSET_1),
43+
]

0 commit comments

Comments
 (0)