Skip to content

Commit fd21dda

Browse files
committed
restore nn.py
1 parent e6587e9 commit fd21dda

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,67 @@ def version_11(cls, ctx, node, **kwargs):
587587
cls.version_1(ctx, node, **kwargs)
588588

589589

590+
@tf_op(["CropAndResize"])
591+
class CropAndResize:
592+
@classmethod
593+
def version_11(cls, ctx, node, **kwargs):
594+
# create loop of resize to cater to tensorflow CropAndResize, one box one iteration
595+
mode = "nearest" if node.get_attr("method") is not None and node.get_attr(
596+
"method").s == b"nearest" else "linear"
597+
extrapolation_value = float(node.get_attr("extrapolation_value", "0").f)
598+
input_x = node.inputs[0]
599+
boxes = node.inputs[1]
600+
box_ind = node.inputs[2]
601+
crop_size = node.inputs[3]
602+
trip_name = utils.make_name(node.name + "_i")
603+
cond_name = utils.make_name(node.name + "_cond")
604+
cond_out_name = utils.make_name(node.name + "cond_out")
605+
g = ctx.create_new_graph_with_same_config()
606+
g.add_graph_input(trip_name, TensorProto.INT64, [1])
607+
g.add_graph_input(cond_name, TensorProto.BOOL, [])
608+
g.parent_graph = ctx
609+
const_zero = g.make_const(utils.make_name(node.name + "_const_zero"), np.array([0], dtype=np.int32))
610+
const_zero_long = g.make_const(utils.make_name(node.name + "_const_zero_long"), np.array([0], dtype=np.int64))
611+
const_one = g.make_const(utils.make_name(node.name + "_const_one"), np.array([1], dtype=np.int32))
612+
const_one_long = g.make_const(utils.make_name(node.name + "_const_one_long"), np.array([1], dtype=np.int64))
613+
index_end = g.make_node("Add", [trip_name, const_one_long.output[0]])
614+
box_index_from = g.make_node("Slice", [box_ind.output[0], trip_name, index_end.output[0]], name="Slice_a")
615+
box_index_to = g.make_node("Add", [box_index_from.output[0], const_one.output[0]])
616+
target_x = g.make_node("Slice", [input_x.output[0], box_index_from.output[0], box_index_to.output[0],
617+
const_zero.output[0]], name="Slice_b")
618+
transposed_x = g.make_node("Transpose", [target_x.output[0]], attr={'perm': constants.NHWC_TO_NCHW})
619+
const_zero_zero = g.make_const(utils.make_name(node.name + "_const_zero_zero"),
620+
np.array([0, 0], dtype=np.float32))
621+
const_one_one = g.make_const(utils.make_name(node.name + "_const_one_one"),
622+
np.array([1, 1], dtype=np.float32))
623+
const_four = g.make_const(utils.make_name(node.name + "_const_four"), np.array([4], dtype=np.int64))
624+
const_empty_float = g.make_const(utils.make_name("const_empty_float"), np.array([], dtype=np.float32))
625+
box = g.make_node("Slice", [boxes.output[0], trip_name, index_end.output[0], const_zero_long.output[0]],
626+
name="Slice_c")
627+
roi_raw = g.make_node("Reshape", [box.output[0], const_four.output[0]])
628+
roi_raw_first_half = GraphBuilder(g).make_slice({"data": roi_raw.output[0], "ends": [2], "starts": [0]})
629+
roi_raw_second_half = GraphBuilder(g).make_slice({"data": roi_raw.output[0], "ends": [4], "starts": [2]})
630+
roi_concat_1 = g.make_node("Concat", [const_zero_zero.output[0], roi_raw_first_half], attr={'axis': 0})
631+
roi_concat_2 = g.make_node("Concat", [const_one_one.output[0], roi_raw_second_half], attr={'axis': 0})
632+
final_roi = g.make_node("Concat", [roi_concat_1.output[0], roi_concat_2.output[0]], attr={'axis': 0})
633+
final_crop_size = build_dynamic_target_size(g, transposed_x, crop_size)
634+
resized_x = g.make_node("Resize", [transposed_x.output[0], final_roi.output[0], const_empty_float.output[0],
635+
final_crop_size.output[0]],
636+
attr={"mode": mode, "extrapolation_value": extrapolation_value,
637+
"coordinate_transformation_mode": "tf_crop_and_resize"})
638+
recovered_x = g.make_node("Transpose", [resized_x.output[0]], attr={'perm': constants.NCHW_TO_NHWC})
639+
squeeze_x = g.make_node("Squeeze", inputs=[recovered_x.output[0]], attr={"axes": [0]})
640+
g.make_node("Identity", [cond_name], outputs=[cond_out_name])
641+
g.add_graph_output(cond_out_name, TensorProto.BOOL, [])
642+
g.add_graph_output(squeeze_x.output[0], ctx.get_dtype(node.input[0]), [-1, -1, -1])
643+
trip_node = ctx.make_node("Size", [box_ind.output[0]])
644+
cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool))
645+
ctx.remove_node(node.name)
646+
inner_loop = ctx.make_node("Loop", [trip_node.output[0], cond_const.output[0]], name=node.name,
647+
outputs=node.output)
648+
inner_loop.set_body_graph_as_attr("body", g)
649+
650+
590651
@tf_op(["ResizeBilinear", "ResizeNearestNeighbor"])
591652
class Resize:
592653
@classmethod

0 commit comments

Comments
 (0)