Skip to content

Commit 757ed9a

Browse files
authored
Merge pull request #443 from onnx/gs/code-reorg
reorg handlers in tfonnx.py into category based files
2 parents 11c9246 + 68e6abb commit 757ed9a

27 files changed

+3183
-2785
lines changed

tests/run_pretrained_models.yaml

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,6 @@ mobilenet_v1_75_192:
185185
outputs:
186186
- MobilenetV1/Predictions/Softmax:0
187187

188-
tiny-yolo:
189-
# works but local file
190-
disabled: true
191-
model: c:/src/darkflow/built_graph/tiny-yolo.pb
192-
input_get: get_beach
193-
inputs:
194-
"input:0": [1, 416, 416, 3]
195-
outputs:
196-
- output:0
197-
rtol: 0.6
198-
199188
nasnet-a_mobile_224:
200189
# has only checkpoint format
201190
disabled: true

tests/test_graph.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@
2222
from tf2onnx.graph import GraphUtil
2323
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2424
from tf2onnx.tfonnx import process_tf_graph
25+
from tf2onnx.handler import tf_op
26+
2527
from common import get_test_config, unittest_main
2628

2729

28-
# pylint: disable=missing-docstring
30+
# pylint: disable=missing-docstring,unused-argument,unused-variable
2931

3032
def onnx_to_graphviz(g, include_attrs=False):
3133
"""Return dot for graph."""
@@ -331,10 +333,10 @@ def rewrite_test(g, ops):
331333
'output [op_type=Identity] input1:0 -> Add input1:0 -> Add Add:0 -> output }',
332334
onnx_to_graphviz(g))
333335

334-
def test_custom_op(self):
335-
"""Custom op test."""
336+
def test_custom_op_depreciated(self):
337+
"""Custom op test using old depreciated api."""
336338

337-
def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
339+
def print_handler(ctx, node, name, args):
338340
# replace tf.Print() with Identity
339341
# T output = Print(T input, data, @list(type) U, @string message, @int first_n, @int summarize)
340342
# becomes:
@@ -360,6 +362,32 @@ def print_handler(ctx, node, name, args): # pylint: disable=unused-argument
360362
self.assertEqual(g.opset, self.config.opset)
361363
self.assertEqual(g.extra_opset, [constants.TENSORFLOW_OPSET])
362364

365+
def test_custom_op(self):
366+
"""Custom op test."""
367+
368+
@tf_op("Print", onnx_op="Identity")
369+
class Print:
370+
@classmethod
371+
def version_1(cls, ctx, node, **kwargs):
372+
self.assertEqual(node.type, "Identity")
373+
node.domain = constants.TENSORFLOW_OPSET.domain
374+
del node.input[1:]
375+
return node
376+
377+
with tf.Session() as sess:
378+
x = tf.placeholder(tf.float32, [2, 3], name="input1")
379+
x_ = tf.Print(x, [x], "hello")
380+
_ = tf.identity(x_, name="output")
381+
g = process_tf_graph(sess.graph,
382+
opset=self.config.opset,
383+
extra_opset=[constants.TENSORFLOW_OPSET])
384+
self.assertEqual(
385+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [domain="ai.onnx.converters.tensorflow" '
386+
'op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
387+
onnx_to_graphviz(g))
388+
self.assertEqual(g.opset, self.config.opset)
389+
self.assertEqual(g.extra_opset, [constants.TENSORFLOW_OPSET])
390+
363391
def test_extra_opset(self):
364392
extra_opset = [
365393
utils.make_opsetid(constants.MICROSOFT_DOMAIN, 1),

tf2onnx/constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,8 @@
2727
TARGET_CAFFE2 = "caffe2"
2828
POSSIBLE_TARGETS = [TARGET_RS4, TARGET_RS5, TARGET_RS6, TARGET_CAFFE2]
2929
DEFAULT_TARGET = []
30+
31+
NCHW_TO_NHWC = [0, 2, 3, 1]
32+
NHWC_TO_NCHW = [0, 3, 1, 2]
33+
HWCN_TO_NCHW = [3, 2, 0, 1]
34+
NCHW_TO_HWCN = [2, 3, 1, 0]

tf2onnx/custom_opsets/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,3 @@
33
""" custom tf2onnx mapping functions. """
44

55
from . import ms
6-
from .. import constants
7-
8-
DOMAIN_OPSETS = {
9-
constants.MICROSOFT_DOMAIN: ms.OPSETS
10-
}

tf2onnx/custom_opsets/ms.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,15 @@
44

55
from onnx.onnx_pb import TensorProto
66
from tf2onnx import constants, utils
7-
from tf2onnx.function.range import make_range_const
7+
from tf2onnx.handler import tf_op
8+
from tf2onnx.onnx_opset import controlflow
89

910

10-
# pylint: disable=unused-argument
11-
12-
def range_op1(ctx, node, name, args):
13-
"""Range."""
14-
# T range = Range(T start, T limit, T delta)
15-
dtype = node.get_attr_int("Tidx")
16-
shape = node.output_shapes[0]
17-
utils.make_sure(dtype is not None, "Tidx of %s is None", node.name)
18-
ctx.remove_node(node.name)
19-
make_range(ctx, node.input[0], node.input[1], node.input[2], node.output[0], name, shape, dtype)
20-
11+
# pylint: disable=unused-argument,missing-docstring
2112

2213
def make_range(ctx, start, limit, delta, output, scope_name, shape, dtype):
2314
if all(ctx.get_node_by_output(n).is_const() for n in [start, limit, delta]) is True:
24-
make_range_const(ctx, start, limit, delta, output, scope_name, shape, dtype)
15+
controlflow.make_range_const(ctx, start, limit, delta, output, scope_name, shape, dtype)
2516
else:
2617
_make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, dtype)
2718

@@ -34,10 +25,14 @@ def _make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, d
3425
domain=constants.MICROSOFT_DOMAIN)
3526

3627

37-
_OPSET_1 = {
38-
"Range": (range_op1, []),
39-
}
40-
41-
OPSETS = [
42-
(1, _OPSET_1),
43-
]
28+
@tf_op("Range", domain=constants.MICROSOFT_DOMAIN)
29+
class Range:
30+
@classmethod
31+
def version_1(cls, ctx, node, **kwargs):
32+
"""Range."""
33+
# T range = Range(T start, T limit, T delta)
34+
dtype = node.get_attr_int("Tidx")
35+
shape = node.output_shapes[0]
36+
utils.make_sure(dtype is not None, "Tidx of %s is None", node.name)
37+
ctx.remove_node(node.name)
38+
make_range(ctx, node.input[0], node.input[1], node.input[2], node.output[0], node.name, shape, dtype)

tf2onnx/function/__init__.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

tf2onnx/function/gathernd.py

Lines changed: 0 additions & 153 deletions
This file was deleted.

0 commit comments

Comments
 (0)