Skip to content

Commit a4cc17f

Browse files
committed
support resize-10
1 parent ca4087e commit a4cc17f

File tree

4 files changed

+44
-8
lines changed

4 files changed

+44
-8
lines changed

tests/test_backend.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1529,7 +1529,10 @@ def test_resize_nearest_neighbor(self):
15291529
_ = tf.identity(x_, name=_TFOUTPUT)
15301530
graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
15311531
if self.config.opset >= 9:
1532-
scale_node = group_nodes_by_type(graph)["Upsample"][0].inputs[1]
1532+
# in opset 10, upsample is removed and resize is defined.
1533+
node_statistic = group_nodes_by_type(graph)
1534+
mapped_node = (node_statistic.get("Upsample") or node_statistic.get("Resize"))[0]
1535+
scale_node = mapped_node.inputs[1]
15331536
self.assertTrue(validate_const_node(scale_node, [1.0, 1.0, 2.0, 2.0]))
15341537

15351538
@check_opset_min_version(9, "resize_nearest_neighbor")
@@ -1557,7 +1560,10 @@ def test_resize_bilinear(self):
15571560
_ = tf.identity(x_, name=_TFOUTPUT)
15581561
graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
15591562
if self.config.opset >= 9:
1560-
scale_node = group_nodes_by_type(graph)["Upsample"][0].inputs[1]
1563+
# in opset 10, upsample is removed and resize is defined.
1564+
node_statistic = group_nodes_by_type(graph)
1565+
mapped_node = (node_statistic.get("Upsample") or node_statistic.get("Resize"))[0]
1566+
scale_node = mapped_node.inputs[1]
15611567
self.assertTrue(validate_const_node(scale_node, [1.0, 1.0, 2.0, 2.0]))
15621568

15631569
@check_opset_min_version(9, "resize_bilinear")
@@ -1573,6 +1579,35 @@ def test_resize_bilinear_with_non_const(self):
15731579
_ = tf.identity(x_, name=_TFOUTPUT)
15741580
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size})
15751581

1582+
@check_opset_min_version(10, "resize scale can less than 1")
1583+
def test_resize_bilinear_with_non_const2(self):
1584+
# scales has an element larger than 1 and also has an element less that 1
1585+
x_shape = [3, 100, 8, 5]
1586+
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
1587+
x = tf.placeholder(tf.float32, x_shape, name=_TFINPUT)
1588+
1589+
x_new_size = np.array([20, 16]).astype(np.int32)
1590+
x_new_size_ = tf.placeholder(shape=[None], dtype=tf.int32, name=_TFINPUT1)
1591+
1592+
x_ = tf.image.resize_bilinear(x, x_new_size_)
1593+
_ = tf.identity(x_, name=_TFOUTPUT)
1594+
self._run_test_case([_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size})
1595+
1596+
@check_opset_min_version(10, "resize scale can less than 1")
1597+
def test_resize_nearest_neighbor2(self):
1598+
x_shape = [1, 300, 20, 2]
1599+
x_new_size = [30, 40]
1600+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
1601+
x = tf.placeholder(tf.float32, x_shape, name=_TFINPUT)
1602+
x_new_size_ = tf.constant(x_new_size)
1603+
x_ = tf.image.resize_nearest_neighbor(x, x_new_size_)
1604+
_ = tf.identity(x_, name=_TFOUTPUT)
1605+
graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
1606+
node_statistic = group_nodes_by_type(graph)
1607+
mapped_node = node_statistic.get("Resize")[0]
1608+
scale_node = mapped_node.inputs[1]
1609+
self.assertTrue(validate_const_node(scale_node, [1.0, 1.0, 0.1, 2.0]))
1610+
15761611
@check_opset_min_version(9, "fill")
15771612
def test_fill_float32(self):
15781613
x_shape = [1, 15, 20, 2]

tf2onnx/onnx_opset/nn.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,14 +464,15 @@ def version_7(cls, ctx, node, **kwargs):
464464

465465
@classmethod
466466
def version_9(cls, ctx, node, **kwargs):
467-
cls._convert_since_9(ctx, node, **kwargs)
467+
cls._convert_since_9(ctx, node, node_type="Upsample")
468468

469469
@classmethod
470470
def version_10(cls, ctx, node, **kwargs):
471-
cls._convert_since_9(ctx, node, **kwargs)
471+
cls._convert_since_9(ctx, node, node_type="Resize")
472472

473473
@classmethod
474-
def _convert_since_9(cls, ctx, node, **kwargs):
474+
def _convert_since_9(cls, ctx, node, node_type):
475+
475476
# float32 out = ResizeBilinear/ResizeNearestNeighbor(T images, int size)
476477
# https://www.tensorflow.org/api_docs/python/tf/image/resize_nearest_neighbor
477478
# wants the input to be NHWC - adjust target_shape to this.
@@ -505,7 +506,7 @@ def _convert_since_9(cls, ctx, node, **kwargs):
505506
scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
506507
# because onnxruntime only supports to scale the last two dims so transpose is inserted
507508
input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": [0, 3, 1, 2]})
508-
upsample = ctx.make_node("Upsample", [input_nchw.output[0], scales.output[0]], attr={"mode": mode})
509+
upsample = ctx.make_node(node_type, [input_nchw.output[0], scales.output[0]], attr={"mode": mode})
509510

510511
shapes = node.output_shapes
511512
dtypes = node.output_dtypes

tf2onnx/onnx_opset/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,7 @@ def version_10(cls, ctx, node, **kwargs):
948948
ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[0], axes=[0])
949949
ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[1], axes=[0, 1])
950950
ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnx_pb.TensorProto.INT64)
951-
# replace original node with nonmaxsurppress + slice + squeeze +cast
951+
# replace original node with nonmaxsurppress + slice + squeeze + cast
952952
dtypes = [ctx.get_dtype(node.output[0])]
953953
shapes = [ctx.get_shape(node.output[0])]
954954
ctx.remove_node(node.name)

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def _shape_handler(self, trans, node):
479479
self._g.remove_node(trans.name)
480480
self._g.remove_node(node.name)
481481
shape_node = self._g.make_node("Shape", [trans.input[0]])
482-
const_node = self._g.make_const("Const", np.array(trans.get_attr("perm").ints))
482+
const_node = self._g.make_const(utils.make_name("Const"), np.array(trans.get_attr("perm").ints))
483483
gather_node = self._g.make_node("Gather", [shape_node.output[0], const_node.output[0]], outputs=node.output)
484484
self._g.set_shape(gather_node.output[0], output_shape)
485485
self._g.set_dtype(gather_node.output[0], output_dtype)

0 commit comments

Comments
 (0)