Skip to content

Commit 8b4ca01

Browse files
committed
Merge branch 'master' of github.com:onnx/tensorflow-onnx into add_missing_import
2 parents 3ef9cf8 + 48fdca5 commit 8b4ca01

File tree

5 files changed

+84
-6
lines changed

5 files changed

+84
-6
lines changed

tests/test_backend.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
# Copyright (c) Microsoft Corporation. All rights reserved.
23
# Licensed under the MIT license.
34

@@ -2488,6 +2489,19 @@ def func(x, z):
24882489
return space_to_batch_nd(x, y, z, name=_TFOUTPUT)
24892490
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT2: z_val})
24902491

2492+
@check_opset_min_version(10, "CropAndResize")
2493+
def test_crop_and_resize(self):
2494+
boxes_val = [[0.5, 0.7, 0.7, 0.9], [0.2, 0.4, 0.4, 0.6]]
2495+
def func(input_x, box_ind):
2496+
boxes = tf.constant(boxes_val, dtype=tf.float32)
2497+
corp_size = tf.constant(np.array([20, 20]).astype(np.int32))
2498+
return tf.image.crop_and_resize(input_x, boxes, box_ind, corp_size, name=_TFOUTPUT, method='bilinear')
2499+
2500+
input_x_val = np.random.randint(low=0, high=256, size=[2, 36, 36, 3]).astype(np.float32) # NHWC
2501+
box_ind_val = np.array([1, 0]).astype(np.int32)
2502+
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x_val, _INPUT2: box_ind_val},
2503+
rtol=1e-04, atol=1e-03)
2504+
24912505
@check_opset_min_version(11, "CropAndResize")
24922506
def test_crop_and_resize_linear(self):
24932507
def func(input_x, boxes, box_ind, corp_size):
@@ -2966,6 +2980,16 @@ def func(input_holder):
29662980
for input_val in input_vals:
29672981
self._run_test_case(func, [_OUTPUT], {_INPUT: input_val})
29682982

2983+
@check_opset_min_version(8)
2984+
def test_broadcast(self):
2985+
input_tensor_val = np.random.randint(low=0, high=256, size=[2, 3]).astype(np.float32)
2986+
new_shape_val = np.array([3, 2, 3]).astype(np.int64)
2987+
2988+
def func(input_tensor, new_shape):
2989+
return tf.broadcast_to(input_tensor, new_shape, _TFOUTPUT)
2990+
2991+
self._run_test_case(func, [_OUTPUT], {_INPUT: input_tensor_val, _INPUT1: new_shape_val})
2992+
29692993

29702994
if __name__ == '__main__':
29712995
unittest_main()

tf2onnx/graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def domain(self, val):
135135
@property
136136
def data_format(self):
137137
"""Return data_format."""
138-
return self.get_attr_str("data_format")
138+
attr_str = self.get_attr_value("data_format")
139+
return "unkown" if attr_str is None else attr_str.decode("utf-8")
139140

140141
@data_format.setter
141142
def data_format(self, val):

tf2onnx/onnx_opset/nn.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,48 @@ def version_11(cls, ctx, node, **kwargs):
589589

590590
@tf_op(["CropAndResize"])
591591
class CropAndResize:
592+
@classmethod
593+
def version_10(cls, ctx, node, **kwargs):
594+
utils.make_sure(node.inputs[1].type == "Const", "boxes input must be a Const")
595+
utils.make_sure(node.inputs[3].type == "Const", "boxes input must be a Const")
596+
name = node.name
597+
output_height = node.inputs[3].get_tensor_value()[0]
598+
output_width = node.inputs[3].get_tensor_value()[1]
599+
rois = node.inputs[1].get_tensor_value()
600+
rois_shape = ctx.get_shape(node.input[1])
601+
img_shape = ctx.get_shape(node.input[0])
602+
transform_rois = np.zeros(list(rois_shape), dtype=np.float32)
603+
for i in range(rois_shape[0]):
604+
y1, x1, y2, x2 = rois[i]
605+
y1 = y1 * (img_shape[1] - 1)
606+
y2 = y2 * (img_shape[1] - 1)
607+
x1 = x1 * (img_shape[2] - 1)
608+
x2 = x2 * (img_shape[2] - 1)
609+
spacing_h = (y2 - y1)
610+
spacing_w = (x2 - x1)
611+
b1 = y1 - 0.5 * spacing_h / (output_height - 1)
612+
a1 = x1 - 0.5 * spacing_w / (output_width - 1)
613+
b2 = y2 + 0.5 * spacing_h / (output_height - 1)
614+
a2 = x2 + 0.5 * spacing_w / (output_width - 1)
615+
transform_rois[i][0] = a1
616+
transform_rois[i][1] = b1
617+
transform_rois[i][2] = a2
618+
transform_rois[i][3] = b2
619+
cast_node = ctx.make_node("Cast", [node.input[2]], attr={"to": onnx_pb.TensorProto.INT64})
620+
bbox_node = ctx.make_const(utils.make_name("bbox"), transform_rois)
621+
dtypes = [ctx.get_dtype(node.output[0])]
622+
shapes = [ctx.get_shape(node.output[0])]
623+
input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": [0, 3, 1, 2]},
624+
name=utils.make_name(node.name))
625+
crop_and_resize = ctx.make_node("RoiAlign", inputs=[input_nchw.output[0], bbox_node.output[0],
626+
cast_node.output[0]],
627+
attr={"output_height": output_height, "output_width": output_width,
628+
"spatial_scale": 1.0, "sampling_ratio": 1},
629+
name=utils.make_name(node.name), dtypes=dtypes, shapes=shapes)
630+
ctx.remove_node(name)
631+
res = ctx.make_node("Transpose", crop_and_resize.output, {"perm": [0, 2, 3, 1]},
632+
name=name, outputs=node.output, shapes=shapes, dtypes=dtypes)
633+
592634
@classmethod
593635
def version_11(cls, ctx, node, **kwargs):
594636
# create loop of resize to cater to tensorflow CropAndResize, one box one iteration

tf2onnx/onnx_opset/tensor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1806,3 +1806,12 @@ def version_11(cls, ctx, node, **kwargs):
18061806
ctx.remove_node(node.name)
18071807
squeezed_result = ctx.make_node('Squeeze', [gathered_result.output[0]], attr={"axes": [-1]},
18081808
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
1809+
1810+
1811+
@tf_op("BroadcastTo")
1812+
class BroadcastTo:
1813+
@classmethod
1814+
def version_8(cls, ctx, node, **kwargs):
1815+
# broadcast by expanding
1816+
node.type = "Expand"
1817+
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -557,11 +557,13 @@ def _pad_handler(self, trans, node):
557557
node.set_attr("pads", new_pads)
558558
return self._switch_transpose_and_node(node, trans)
559559
if node.inputs[1].is_const():
560-
pads = node.inputs[1].get_tensor_value()
561-
# NHWC->NCHW
562-
new_pads = np.array([pads[0], pads[3], pads[1], pads[2], pads[4], pads[7], pads[5], pads[6]],
563-
dtype=np.int64)
564-
node.inputs[1].set_tensor_value(new_pads)
560+
if node.inputs[1].data_format in ["NHWC", "unkown"]:
561+
pads = node.inputs[1].get_tensor_value()
562+
# NHWC->NCHW
563+
new_pads = np.array([pads[0], pads[3], pads[1], pads[2], pads[4], pads[7], pads[5], pads[6]],
564+
dtype=np.int64)
565+
node.inputs[1].set_tensor_value(new_pads)
566+
node.inputs[1].data_format = "NCHW"
565567
return self._switch_transpose_and_node(node, trans)
566568
return False
567569

0 commit comments

Comments
 (0)