Skip to content

Commit 53bd117

Browse files
committed
switch to contrib cropandresize
1 parent abb5701 commit 53bd117

File tree

2 files changed

+10
-61
lines changed

2 files changed

+10
-61
lines changed

tf2onnx/custom_opsets/onnx_ml.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,13 @@ def version_8(cls, ctx, node, **kwargs):
3939
customer_nodes = ctx.find_output_consumers(table_node.output[0])
4040
if len(customer_nodes) == 0:
4141
ctx.remove_node(table_node.name)
42+
43+
@tf_op(["CropAndResize"])
44+
class CropAndResize:
45+
@classmethod
46+
def version_11(cls, ctx, node, **kwargs):
47+
""" utilize contrib cropandresize """
48+
node.attr['method'].name = 'mode'
49+
node.domain = constants.MICROSOFT_DOMAIN
50+
ctx.insert_new_node_on_input(node, "Transpose", node.input[0], perm=[0,3,1,2])
51+
ctx.insert_new_node_on_output("Transpose", node.output[0], node.name + '_transposed', None, perm=[0,2,3,1])

tf2onnx/onnx_opset/nn.py

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -587,67 +587,6 @@ def version_11(cls, ctx, node, **kwargs):
587587
cls.version_1(ctx, node, **kwargs)
588588

589589

590-
@tf_op(["CropAndResize"])
591-
class CropAndResize:
592-
@classmethod
593-
def version_11(cls, ctx, node, **kwargs):
594-
# create loop of resize to cater to tensorflow CropAndResize, one box one iteration
595-
mode = "nearest" if node.get_attr("method") is not None and node.get_attr(
596-
"method").s == b"nearest" else "linear"
597-
extrapolation_value = float(node.get_attr("extrapolation_value", "0").f)
598-
input_x = node.inputs[0]
599-
boxes = node.inputs[1]
600-
box_ind = node.inputs[2]
601-
crop_size = node.inputs[3]
602-
trip_name = utils.make_name(node.name + "_i")
603-
cond_name = utils.make_name(node.name + "_cond")
604-
cond_out_name = utils.make_name(node.name + "cond_out")
605-
g = ctx.create_new_graph_with_same_config()
606-
g.add_graph_input(trip_name, TensorProto.INT64, [1])
607-
g.add_graph_input(cond_name, TensorProto.BOOL, [])
608-
g.parent_graph = ctx
609-
const_zero = g.make_const(utils.make_name(node.name + "_const_zero"), np.array([0], dtype=np.int32))
610-
const_zero_long = g.make_const(utils.make_name(node.name + "_const_zero_long"), np.array([0], dtype=np.int64))
611-
const_one = g.make_const(utils.make_name(node.name + "_const_one"), np.array([1], dtype=np.int32))
612-
const_one_long = g.make_const(utils.make_name(node.name + "_const_one_long"), np.array([1], dtype=np.int64))
613-
index_end = g.make_node("Add", [trip_name, const_one_long.output[0]])
614-
box_index_from = g.make_node("Slice", [box_ind.output[0], trip_name, index_end.output[0]], name="Slice_a")
615-
box_index_to = g.make_node("Add", [box_index_from.output[0], const_one.output[0]])
616-
target_x = g.make_node("Slice", [input_x.output[0], box_index_from.output[0], box_index_to.output[0],
617-
const_zero.output[0]], name="Slice_b")
618-
transposed_x = g.make_node("Transpose", [target_x.output[0]], attr={'perm': constants.NHWC_TO_NCHW})
619-
const_zero_zero = g.make_const(utils.make_name(node.name + "_const_zero_zero"),
620-
np.array([0, 0], dtype=np.float32))
621-
const_one_one = g.make_const(utils.make_name(node.name + "_const_one_one"),
622-
np.array([1, 1], dtype=np.float32))
623-
const_four = g.make_const(utils.make_name(node.name + "_const_four"), np.array([4], dtype=np.int64))
624-
const_empty_float = g.make_const(utils.make_name("const_empty_float"), np.array([], dtype=np.float32))
625-
box = g.make_node("Slice", [boxes.output[0], trip_name, index_end.output[0], const_zero_long.output[0]],
626-
name="Slice_c")
627-
roi_raw = g.make_node("Reshape", [box.output[0], const_four.output[0]])
628-
roi_raw_first_half = GraphBuilder(g).make_slice({"data": roi_raw.output[0], "ends": [2], "starts": [0]})
629-
roi_raw_second_half = GraphBuilder(g).make_slice({"data": roi_raw.output[0], "ends": [4], "starts": [2]})
630-
roi_concat_1 = g.make_node("Concat", [const_zero_zero.output[0], roi_raw_first_half], attr={'axis': 0})
631-
roi_concat_2 = g.make_node("Concat", [const_one_one.output[0], roi_raw_second_half], attr={'axis': 0})
632-
final_roi = g.make_node("Concat", [roi_concat_1.output[0], roi_concat_2.output[0]], attr={'axis': 0})
633-
final_crop_size = build_dynamic_target_size(g, transposed_x, crop_size)
634-
resized_x = g.make_node("Resize", [transposed_x.output[0], final_roi.output[0], const_empty_float.output[0],
635-
final_crop_size.output[0]],
636-
attr={"mode": mode, "extrapolation_value": extrapolation_value,
637-
"coordinate_transformation_mode": "tf_crop_and_resize"})
638-
recovered_x = g.make_node("Transpose", [resized_x.output[0]], attr={'perm': constants.NCHW_TO_NHWC})
639-
squeeze_x = g.make_node("Squeeze", inputs=[recovered_x.output[0]], attr={"axes": [0]})
640-
g.make_node("Identity", [cond_name], outputs=[cond_out_name])
641-
g.add_graph_output(cond_out_name, TensorProto.BOOL, [])
642-
g.add_graph_output(squeeze_x.output[0], ctx.get_dtype(node.input[0]), [-1, -1, -1])
643-
trip_node = ctx.make_node("Size", [box_ind.output[0]])
644-
cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool))
645-
ctx.remove_node(node.name)
646-
inner_loop = ctx.make_node("Loop", [trip_node.output[0], cond_const.output[0]], name=node.name,
647-
outputs=node.output)
648-
inner_loop.set_body_graph_as_attr("body", g)
649-
650-
651590
@tf_op(["ResizeBilinear", "ResizeNearestNeighbor"])
652591
class Resize:
653592
@classmethod

0 commit comments

Comments
 (0)