Skip to content

Commit 5f3dd1f

Browse files
authored
Merge pull request #857 from PreethaVeera/Preetha/crop_and_resize_opset10
Add support for crop_and_resize for opset 10.
2 parents fdabeea + 3863d32 commit 5f3dd1f

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

tests/test_backend.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
# Copyright (c) Microsoft Corporation. All rights reserved.
23
# Licensed under the MIT license.
34

@@ -2488,6 +2489,19 @@ def func(x, z):
24882489
return space_to_batch_nd(x, y, z, name=_TFOUTPUT)
24892490
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT2: z_val})
24902491

2492+
@check_opset_min_version(10, "CropAndResize")
def test_crop_and_resize(self):
    """CropAndResize (bilinear) against the opset-10 conversion path."""
    # Two normalized [y1, x1, y2, x2] boxes; the opset-10 converter needs
    # them to be graph constants, so they are baked into the tf function.
    box_corners = [[0.5, 0.7, 0.7, 0.9], [0.2, 0.4, 0.4, 0.6]]

    def func(input_x, box_ind):
        boxes = tf.constant(box_corners, dtype=tf.float32)
        crop_size = tf.constant(np.array([20, 20]).astype(np.int32))
        return tf.image.crop_and_resize(input_x, boxes, box_ind, crop_size,
                                        name=_TFOUTPUT, method='bilinear')

    # Batch of two NHWC images; the boxes sample images 1 and 0 respectively.
    image_batch = np.random.randint(low=0, high=256, size=[2, 36, 36, 3]).astype(np.float32)
    box_indices = np.array([1, 0]).astype(np.int32)
    # Loose tolerances: the opset-10 lowering only approximates TF sampling.
    self._run_test_case(func, [_OUTPUT], {_INPUT: image_batch, _INPUT2: box_indices},
                        rtol=1e-04, atol=1e-03)
2504+
24912505
@check_opset_min_version(11, "CropAndResize")
24922506
def test_crop_and_resize_linear(self):
24932507
def func(input_x, boxes, box_ind, corp_size):

tf2onnx/onnx_opset/nn.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,48 @@ def version_11(cls, ctx, node, **kwargs):
589589

590590
@tf_op(["CropAndResize"])
591591
class CropAndResize:
592+
@classmethod
def version_10(cls, ctx, node, **kwargs):
    """Lower tf CropAndResize onto ONNX RoiAlign for opset 10.

    TF supplies normalized [y1, x1, y2, x2] boxes over an NHWC image,
    while RoiAlign wants pixel-coordinate [x1, y1, x2, y2] boxes over
    NCHW, so the boxes are rescaled and reordered at conversion time and
    the image is transposed into NCHW and back around the RoiAlign node.
    Requires boxes and crop_size to be constants and the image height and
    width to be statically known.
    """
    utils.make_sure(node.inputs[1].type == "Const", "boxes input must be a Const")
    # BUG FIX: this message previously said "boxes" for the crop_size input.
    utils.make_sure(node.inputs[3].type == "Const", "crop_size input must be a Const")
    name = node.name
    output_height = node.inputs[3].get_tensor_value()[0]
    output_width = node.inputs[3].get_tensor_value()[1]
    # The corner-aligned correction below divides by (size - 1); a 1x1 crop
    # previously crashed with ZeroDivisionError — reject it explicitly.
    utils.make_sure(output_height > 1 and output_width > 1,
                    "crop_size dims must be greater than 1 for opset 10 CropAndResize")
    rois = node.inputs[1].get_tensor_value()
    rois_shape = ctx.get_shape(node.input[1])
    img_shape = ctx.get_shape(node.input[0])
    # Dynamic H/W would silently produce wrong pixel coordinates below.
    utils.make_sure(img_shape is not None and img_shape[1] > 0 and img_shape[2] > 0,
                    "input image height and width must be static for opset 10 CropAndResize")
    transform_rois = np.zeros(list(rois_shape), dtype=np.float32)
    for i in range(rois_shape[0]):
        # TF boxes are normalized [y1, x1, y2, x2]; scale to pixel coords.
        y1, x1, y2, x2 = rois[i]
        y1 = y1 * (img_shape[1] - 1)
        y2 = y2 * (img_shape[1] - 1)
        x1 = x1 * (img_shape[2] - 1)
        x2 = x2 * (img_shape[2] - 1)
        spacing_h = (y2 - y1)
        spacing_w = (x2 - x1)
        # Expand the box half a sample spacing per side so RoiAlign's bin
        # centers land on crop_and_resize's corner-aligned sample points.
        b1 = y1 - 0.5 * spacing_h / (output_height - 1)
        a1 = x1 - 0.5 * spacing_w / (output_width - 1)
        b2 = y2 + 0.5 * spacing_h / (output_height - 1)
        a2 = x2 + 0.5 * spacing_w / (output_width - 1)
        # RoiAlign expects [x1, y1, x2, y2] order.
        transform_rois[i][0] = a1
        transform_rois[i][1] = b1
        transform_rois[i][2] = a2
        transform_rois[i][3] = b2
    # RoiAlign's batch_indices input must be int64.
    cast_node = ctx.make_node("Cast", [node.input[2]], attr={"to": onnx_pb.TensorProto.INT64})
    bbox_node = ctx.make_const(utils.make_name("bbox"), transform_rois)
    dtypes = [ctx.get_dtype(node.output[0])]
    shapes = [ctx.get_shape(node.output[0])]
    # RoiAlign operates on NCHW; transpose in, and back to NHWC afterwards.
    input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": [0, 3, 1, 2]},
                               name=utils.make_name(node.name))
    crop_and_resize = ctx.make_node("RoiAlign",
                                    inputs=[input_nchw.output[0], bbox_node.output[0],
                                            cast_node.output[0]],
                                    attr={"output_height": output_height,
                                          "output_width": output_width,
                                          "spatial_scale": 1.0, "sampling_ratio": 1},
                                    name=utils.make_name(node.name),
                                    dtypes=dtypes, shapes=shapes)
    ctx.remove_node(name)
    # Final transpose takes over the original node's name and outputs.
    ctx.make_node("Transpose", crop_and_resize.output, {"perm": [0, 2, 3, 1]},
                  name=name, outputs=node.output, shapes=shapes, dtypes=dtypes)
633+
592634
@classmethod
593635
def version_11(cls, ctx, node, **kwargs):
594636
# create loop of resize to cater to tensorflow CropAndResize, one box one iteration

0 commit comments

Comments
 (0)