Skip to content

Commit ec39a9f

Browse files
Fix bug in CropAndResize for empty tensors (#1588)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 507a6aa commit ec39a9f

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

tests/test_backend.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3791,6 +3791,18 @@ def func(input_x, boxes, box_ind, corp_size):
37913791
{_INPUT: input_x_val, _INPUT1: boxes_val, _INPUT2: box_ind_val, _INPUT3: corp_size_val},
37923792
rtol=1e-04, atol=1e-03)
37933793

3794+
@check_opset_min_version(11, "CropAndResize")
3795+
def test_crop_and_resize_empty_tensor(self):
3796+
def func(input_x, boxes, box_ind, corp_size):
3797+
return tf.image.crop_and_resize(input_x, boxes, box_ind, corp_size, name=_TFOUTPUT, extrapolation_value=1.0)
3798+
input_x_val = np.random.randint(low=0, high=256, size=[0, 36, 36, 3]).astype(np.float32) # NHWC
3799+
boxes_val = np.array([]).astype(np.float32).reshape([0, 4])
3800+
box_ind_val = np.array([]).astype(np.int32)
3801+
corp_size_val = np.array([40, 40]).astype(np.int32)
3802+
self._run_test_case(func, [_OUTPUT],
3803+
{_INPUT: input_x_val, _INPUT1: boxes_val, _INPUT2: box_ind_val, _INPUT3: corp_size_val},
3804+
rtol=1e-04, atol=1e-03)
3805+
37943806
def test_batch_to_space3d(self):
37953807
block_size = [2, 2]
37963808
crop = [[0, 1], [2, 1]]

tf2onnx/onnx_opset/nn.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,10 @@ def build_dynamic_target_size(ctx, transposed_intput, target_hw):
318318
shape_of_transposed_input = ctx.make_node("Shape", [transposed_intput])
319319
first_half_of_shape = GraphBuilder(ctx).make_slice(
320320
{"data": shape_of_transposed_input.output[0], "ends": [2], "starts": [0]})
321-
target_size_int64 = ctx.make_node("Cast", [target_hw], attr={'to': TensorProto.INT64})
321+
if ctx.get_dtype(target_hw) != TensorProto.INT64:
322+
target_hw = ctx.make_node("Cast", [target_hw], attr={'to': TensorProto.INT64}).output[0]
322323
# We build a tensor containing [n c nh nw]
323-
final_target_size = ctx.make_node("Concat", [first_half_of_shape, target_size_int64.output[0]], {'axis': 0})
324+
final_target_size = ctx.make_node("Concat", [first_half_of_shape, target_hw], {'axis': 0})
324325
return final_target_size
325326

326327

@@ -1183,9 +1184,13 @@ def any_version_after11(cls, opset, ctx, node, **kwargs):
11831184
"method").s == b"nearest" else "linear"
11841185
extrapolation_value = float(node.get_attr("extrapolation_value", "0").f)
11851186
input_x = node.input[0]
1187+
x_shape = ctx.make_node("Shape", [input_x]).output[0]
1188+
num_channels = GraphBuilder(ctx).make_slice({"data": x_shape, "starts": [3], "ends": [4], "axes": [0]})
11861189
boxes = node.input[1]
11871190
box_ind = node.input[2]
11881191
crop_size = node.input[3]
1192+
if ctx.get_dtype(crop_size) != TensorProto.INT64:
1193+
crop_size = ctx.make_node("Cast", [crop_size], attr={'to': TensorProto.INT64}).output[0]
11891194
trip_name = utils.make_name(node.name + "_i")
11901195
cond_name = utils.make_name(node.name + "_cond")
11911196
cond_out_name = utils.make_name(node.name + "cond_out")
@@ -1233,6 +1238,10 @@ def any_version_after11(cls, opset, ctx, node, **kwargs):
12331238
branches = {"body": g}
12341239
inner_loop = ctx.make_node("Loop", [trip_node.output[0], cond_const.output[0]], name=node.name,
12351240
outputs=node.output, branches=branches)
1241+
const_neg_one = ctx.make_const(utils.make_name("const_neg_one"), np.array([-1], np.int64)).output[0]
1242+
final_shape = ctx.make_node("Concat", [const_neg_one, crop_size, num_channels], attr={'axis': 0}).output[0]
1243+
# This reshape fixes the case when there are no iterations and the scan output is empty.
1244+
ctx.insert_new_node_on_output("Reshape", inner_loop.output[0], inputs=[inner_loop.output[0], final_shape])
12361245

12371246
@classmethod
12381247
def version_11(cls, ctx, node, **kwargs):

0 commit comments

Comments
 (0)