Skip to content

Commit 0eaa212

Browse files
committed
reorg habndlers tfonnx.py into category based files
1 parent e8ec761 commit 0eaa212

26 files changed

+3081
-2779
lines changed

tests/run_pretrained_models.yaml

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,6 @@ mobilenet_v1_75_192:
185185
outputs:
186186
- MobilenetV1/Predictions/Softmax:0
187187

188-
tiny-yolo:
189-
# works but local file
190-
disabled: true
191-
model: c:/src/darkflow/built_graph/tiny-yolo.pb
192-
input_get: get_beach
193-
inputs:
194-
"input:0": [1, 416, 416, 3]
195-
outputs:
196-
- output:0
197-
rtol: 0.6
198-
199188
nasnet-a_mobile_224:
200189
# has only checkpoint format
201190
disabled: true

tf2onnx/constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,8 @@
2727
TARGET_CAFFE2 = "caffe2"
2828
POSSIBLE_TARGETS = [TARGET_RS4, TARGET_RS5, TARGET_RS6, TARGET_CAFFE2]
2929
DEFAULT_TARGET = []
30+
31+
NCHW_TO_NHWC = [0, 2, 3, 1]
32+
NHWC_TO_NCHW = [0, 3, 1, 2]
33+
HWCN_TO_NCHW = [3, 2, 0, 1]
34+
NCHW_TO_HWCN = [2, 3, 1, 0]

tf2onnx/custom_opsets/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,3 @@
44

55
from . import ms
66
from .. import constants
7-
8-
DOMAIN_OPSETS = {
9-
constants.MICROSOFT_DOMAIN: ms.OPSETS
10-
}

tf2onnx/custom_opsets/ms.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,15 @@
44

55
from onnx.onnx_pb import TensorProto
66
from tf2onnx import constants, utils
7-
from tf2onnx.function.range import make_range_const
7+
from tf2onnx.handler import tf_op
8+
from tf2onnx.onnx_opset import controlflow
89

910

10-
# pylint: disable=unused-argument
11-
12-
def range_op1(ctx, node, name, args):
13-
"""Range."""
14-
# T range = Range(T start, T limit, T delta)
15-
dtype = node.get_attr_int("Tidx")
16-
shape = node.output_shapes[0]
17-
utils.make_sure(dtype is not None, "Tidx of %s is None", node.name)
18-
ctx.remove_node(node.name)
19-
make_range(ctx, node.input[0], node.input[1], node.input[2], node.output[0], name, shape, dtype)
20-
11+
# pylint: disable=unused-argument,missing-docstring
2112

2213
def make_range(ctx, start, limit, delta, output, scope_name, shape, dtype):
2314
if all(ctx.get_node_by_output(n).is_const() for n in [start, limit, delta]) is True:
24-
make_range_const(ctx, start, limit, delta, output, scope_name, shape, dtype)
15+
controlflow.make_range_const(ctx, start, limit, delta, output, scope_name, shape, dtype)
2516
else:
2617
_make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, dtype)
2718

@@ -34,10 +25,14 @@ def _make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, d
3425
domain=constants.MICROSOFT_DOMAIN)
3526

3627

37-
_OPSET_1 = {
38-
"Range": (range_op1, []),
39-
}
40-
41-
OPSETS = [
42-
(1, _OPSET_1),
43-
]
28+
@tf_op("Range", domain=constants.MICROSOFT_DOMAIN)
29+
class Range:
30+
@classmethod
31+
def version_1(cls, ctx, node, **kwargs):
32+
"""Range."""
33+
# T range = Range(T start, T limit, T delta)
34+
dtype = node.get_attr_int("Tidx")
35+
shape = node.output_shapes[0]
36+
utils.make_sure(dtype is not None, "Tidx of %s is None", node.name)
37+
ctx.remove_node(node.name)
38+
make_range(ctx, node.input[0], node.input[1], node.input[2], node.output[0], node.name, shape, dtype)

tf2onnx/function/__init__.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

tf2onnx/function/gathernd.py

Lines changed: 0 additions & 153 deletions
This file was deleted.

0 commit comments

Comments
 (0)