Skip to content

Commit 8c72723

Browse files
authored
Merge pull request #481 from zhijxu-MS/push_branch
support resize, nonmaxsuppression in opset 10; and fix some bugs.
2 parents 022b718 + f97b4d8 commit 8c72723

File tree

6 files changed

+88
-13
lines changed

6 files changed

+88
-13
lines changed

tests/test_backend.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1529,7 +1529,10 @@ def test_resize_nearest_neighbor(self):
15291529
_ = tf.identity(x_, name=_TFOUTPUT)
15301530
graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
15311531
if self.config.opset >= 9:
1532-
scale_node = group_nodes_by_type(graph)["Upsample"][0].inputs[1]
1532+
# in opset 10, upsample is removed and resize is defined.
1533+
node_statistic = group_nodes_by_type(graph)
1534+
mapped_node = (node_statistic.get("Upsample") or node_statistic.get("Resize"))[0]
1535+
scale_node = mapped_node.inputs[1]
15331536
self.assertTrue(validate_const_node(scale_node, [1.0, 1.0, 2.0, 2.0]))
15341537

15351538
@check_opset_min_version(9, "resize_nearest_neighbor")
@@ -1557,7 +1560,10 @@ def test_resize_bilinear(self):
15571560
_ = tf.identity(x_, name=_TFOUTPUT)
15581561
graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
15591562
if self.config.opset >= 9:
1560-
scale_node = group_nodes_by_type(graph)["Upsample"][0].inputs[1]
1563+
# in opset 10, upsample is removed and resize is defined.
1564+
node_statistic = group_nodes_by_type(graph)
1565+
mapped_node = (node_statistic.get("Upsample") or node_statistic.get("Resize"))[0]
1566+
scale_node = mapped_node.inputs[1]
15611567
self.assertTrue(validate_const_node(scale_node, [1.0, 1.0, 2.0, 2.0]))
15621568

15631569
@check_opset_min_version(9, "resize_bilinear")
@@ -1573,6 +1579,35 @@ def test_resize_bilinear_with_non_const(self):
15731579
_ = tf.identity(x_, name=_TFOUTPUT)
15741580
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size})
15751581

1582+
@check_opset_min_version(10, "resize scale can less than 1")
1583+
def test_resize_bilinear_with_non_const2(self):
1584+
# scales has an element larger than 1 and also has an element less that 1
1585+
x_shape = [3, 100, 8, 5]
1586+
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
1587+
x = tf.placeholder(tf.float32, x_shape, name=_TFINPUT)
1588+
1589+
x_new_size = np.array([20, 16]).astype(np.int32)
1590+
x_new_size_ = tf.placeholder(shape=[None], dtype=tf.int32, name=_TFINPUT1)
1591+
1592+
x_ = tf.image.resize_bilinear(x, x_new_size_)
1593+
_ = tf.identity(x_, name=_TFOUTPUT)
1594+
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size})
1595+
1596+
@check_opset_min_version(10, "resize scale can less than 1")
1597+
def test_resize_nearest_neighbor2(self):
1598+
x_shape = [1, 300, 20, 2]
1599+
x_new_size = [30, 40]
1600+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
1601+
x = tf.placeholder(tf.float32, x_shape, name=_TFINPUT)
1602+
x_new_size_ = tf.constant(x_new_size)
1603+
x_ = tf.image.resize_nearest_neighbor(x, x_new_size_)
1604+
_ = tf.identity(x_, name=_TFOUTPUT)
1605+
graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
1606+
node_statistic = group_nodes_by_type(graph)
1607+
mapped_node = node_statistic.get("Resize")[0]
1608+
scale_node = mapped_node.inputs[1]
1609+
self.assertTrue(validate_const_node(scale_node, [1.0, 1.0, 0.1, 2.0]))
1610+
15761611
@check_opset_min_version(9, "fill")
15771612
def test_fill_float32(self):
15781613
x_shape = [1, 15, 20, 2]
@@ -2117,5 +2152,19 @@ def test_isinf(self):
21172152
self._run_test_case([_OUTPUT], {_INPUT: x_val})
21182153
tf.reset_default_graph()
21192154

2155+
@check_opset_min_version(10, "NonMaxSuppression")
2156+
def test_non_max_suppression(self):
2157+
box_num = 10
2158+
boxes_val = np.random.random_sample([box_num, 4]).astype(np.float32)
2159+
scores_val = np.random.random_sample([box_num]).astype(np.float32)
2160+
boxes = tf.placeholder(tf.float32, shape=[None, 4], name=_TFINPUT)
2161+
scores = tf.placeholder(tf.float32, shape=[None], name=_TFINPUT1)
2162+
res1 = tf.image.non_max_suppression(boxes, scores, max_output_size=int(box_num / 2))
2163+
res2 = tf.image.non_max_suppression(boxes, scores, max_output_size=0)
2164+
_ = tf.identity(res1, name=_TFOUTPUT)
2165+
_ = tf.identity(res2, name=_TFOUTPUT1)
2166+
self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
2167+
2168+
21202169
if __name__ == '__main__':
21212170
unittest_main()

tf2onnx/handler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
# pylint: disable=unused-argument,missing-docstring,invalid-name
1717

18+
1819
class tf_op:
1920
"""Class to implement the decorator to register handlers that map tf to onnx."""
2021

tf2onnx/onnx_opset/nn.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,14 +464,15 @@ def version_7(cls, ctx, node, **kwargs):
464464

465465
@classmethod
466466
def version_9(cls, ctx, node, **kwargs):
467-
cls._convert_since_9(ctx, node, **kwargs)
467+
cls._convert_since_9(ctx, node, op_type="Upsample")
468468

469469
@classmethod
470470
def version_10(cls, ctx, node, **kwargs):
471-
cls._convert_since_9(ctx, node, **kwargs)
471+
cls._convert_since_9(ctx, node, op_type="Resize")
472472

473473
@classmethod
474-
def _convert_since_9(cls, ctx, node, **kwargs):
474+
def _convert_since_9(cls, ctx, node, op_type):
475+
475476
# float32 out = ResizeBilinear/ResizeNearestNeighbor(T images, int size)
476477
# https://www.tensorflow.org/api_docs/python/tf/image/resize_nearest_neighbor
477478
# wants the input to be NHWC - adjust target_shape to this.
@@ -505,7 +506,7 @@ def _convert_since_9(cls, ctx, node, **kwargs):
505506
scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
506507
# because onnxruntime only supports to scale the last two dims so transpose is inserted
507508
input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": [0, 3, 1, 2]})
508-
upsample = ctx.make_node("Upsample", [input_nchw.output[0], scales.output[0]], attr={"mode": mode})
509+
upsample = ctx.make_node(op_type, [input_nchw.output[0], scales.output[0]], attr={"mode": mode})
509510

510511
shapes = node.output_shapes
511512
dtypes = node.output_dtypes

tf2onnx/onnx_opset/tensor.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
# pylint: disable=unused-argument,missing-docstring,unused-variable
2626

27+
2728
def _convert_shapenode_to_int64(ctx, node, input_number):
2829
"""cast int32 shape into int64 shape."""
2930
name = node.input[input_number]
@@ -55,6 +56,7 @@ def _wrap_concat_with_cast(ctx, node):
5556
ctx.set_dtype(output_cast.output[0], dtype)
5657
ctx.copy_shape(output_name, output_cast.output[0])
5758

59+
5860
@tf_op(["Size", "Flatten", "Dropout"])
5961
class DirectOp:
6062
@classmethod
@@ -296,8 +298,6 @@ def version_4(cls, ctx, node, **kwargs):
296298
ctx.remove_input(node, node.input[2])
297299
node.set_attr("axis", axis)
298300

299-
INT64_MAX = np.iinfo(np.int64).max
300-
301301

302302
def _make_gathernd_inner_loop(ctx, params, index, dtype):
303303
"""create the inner loop for GatherNd."""
@@ -343,7 +343,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt
343343
# reshape indices into [sum(indices[:-1]), indices[-1]]
344344
indices_shape = ctx.make_node("Shape", [indices], dtypes=[TensorProto.INT64])
345345
indices_size = ctx.make_node("Size", [indices])
346-
attr = {"axes": [0], "ends": [INT64_MAX], "starts": [-1]}
346+
attr = {"axes": [0], "ends": [utils.get_max_value(np.int64)], "starts": [-1]}
347347
inputs_map = {"data": indices_shape.output[0], **attr}
348348
inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
349349
outter_shape = ctx.make_node("Div",
@@ -401,7 +401,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt
401401
[inner_loop_shape.output[0], one_const.output[0]],
402402
attr={"axis": 0},
403403
dtypes=[TensorProto.INT64])
404-
attr = {"axes": [0], "ends": [INT64_MAX], "starts": [1]}
404+
attr = {"axes": [0], "ends": [utils.get_max_value(np.int64)], "starts": [1]}
405405
inputs_map = {"data": inner_loop_shape_.output[0], **attr}
406406
output_inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
407407
attr = {"axes": [0], "ends": [-1], "starts": [0]}
@@ -932,3 +932,29 @@ def version_10(cls, ctx, node, **kwargs):
932932
utils.make_sure(node_dtype, "Dtype of {} is None".format(node.name))
933933
if node_dtype not in [onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.DOUBLE]:
934934
raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")
935+
936+
937+
@tf_op(["NonMaxSuppressionV2", "NonMaxSuppressionV3"], onnx_op="NonMaxSuppression")
938+
class NonMaxSuppression:
939+
@classmethod
940+
def version_10(cls, ctx, node, **kwargs):
941+
# int32 = NonMaxSuppressionV2(T boxes, T scores, int32 max_output_size, T iou_threshold, T score_threshold)
942+
# int64 = NonMaxSuppression(T boxes, T scores, int64 max_output_size, T iou_threshold, T score_threshold),
943+
# T means float32 here, the last 3 params are optional
944+
# tf boxes is 2D ([boxes_num, 4]) while onnx is 3D ([num_batches, boxes_num, 4])
945+
# tf scores is 1D ([boxes_num])while onnx is 2D ([num_batches, num_classes, boxes_num])
946+
# onnx output is [num_selected_boxes, 3], the meaning of last dim is [batch_index, class_index, box_index]
947+
# while tf's output is [num_selected_boxes]
948+
ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[0], axes=[0])
949+
ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[1], axes=[0, 1])
950+
ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnx_pb.TensorProto.INT64)
951+
# replace original node with nonmaxsurppress + slice + squeeze + cast
952+
dtypes = [ctx.get_dtype(node.output[0])]
953+
shapes = [ctx.get_shape(node.output[0])]
954+
ctx.remove_node(node.name)
955+
new_nonmaxsurppress = ctx.make_node(node.type, node.input).output[0]
956+
slice_op = GraphBuilder(ctx).make_slice({"data": new_nonmaxsurppress,
957+
"axes": [1], "ends": [3], "starts": [2]})
958+
squeeze_op = ctx.make_node("Squeeze", [slice_op], attr={"axes": [1]})
959+
ctx.make_node("Cast", inputs=squeeze_op.output, attr={"to": onnx_pb.TensorProto.INT32},
960+
name=node.name, outputs=node.output, dtypes=dtypes, shapes=shapes)

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def _shape_handler(self, trans, node):
479479
self._g.remove_node(trans.name)
480480
self._g.remove_node(node.name)
481481
shape_node = self._g.make_node("Shape", [trans.input[0]])
482-
const_node = self._g.make_const("Const", np.array(trans.get_attr("perm").ints))
482+
const_node = self._g.make_const(utils.make_name("Const"), np.array(trans.get_attr("perm").ints))
483483
gather_node = self._g.make_node("Gather", [shape_node.output[0], const_node.output[0]], outputs=node.output)
484484
self._g.set_shape(gather_node.output[0], output_shape)
485485
self._g.set_dtype(gather_node.output[0], output_dtype)

tf2onnx/tfonnx.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -471,9 +471,7 @@ def rewrite_incomplete_type_support_rs5(g, ops):
471471
def rewrite_incomplete_type_support_rs6(g, ops):
472472
impacted_ops = [
473473
"Div",
474-
"Greater",
475474
"IsNaN",
476-
"Less",
477475
"Max",
478476
"Min",
479477
"ReduceSum",

0 commit comments

Comments
 (0)