24 | 24 |
25 | 25 | # pylint: disable=unused-argument,missing-docstring,unused-variable
26 | 26 |
| 27 | +
27 | 28 | def _convert_shapenode_to_int64(ctx, node, input_number):
28 | 29 |     """cast int32 shape into int64 shape."""
29 | 30 |     name = node.input[input_number]
@@ -55,6 +56,7 @@ def _wrap_concat_with_cast(ctx, node):
55 | 56 |     ctx.set_dtype(output_cast.output[0], dtype)
56 | 57 |     ctx.copy_shape(output_name, output_cast.output[0])
57 | 58 |
| 59 | +
58 | 60 | @tf_op(["Size", "Flatten", "Dropout"])
59 | 61 | class DirectOp:
60 | 62 |     @classmethod
@@ -296,8 +298,6 @@ def version_4(cls, ctx, node, **kwargs):
|
296 | 298 | ctx.remove_input(node, node.input[2])
|
297 | 299 | node.set_attr("axis", axis)
|
298 | 300 |
|
299 |
| -INT64_MAX = np.iinfo(np.int64).max |
300 |
| - |
301 | 301 |
|
302 | 302 | def _make_gathernd_inner_loop(ctx, params, index, dtype):
|
303 | 303 | """create the inner loop for GatherNd."""
|
@@ -343,7 +343,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt
|
343 | 343 | # reshape indices into [sum(indices[:-1]), indices[-1]]
|
344 | 344 | indices_shape = ctx.make_node("Shape", [indices], dtypes=[TensorProto.INT64])
|
345 | 345 | indices_size = ctx.make_node("Size", [indices])
|
346 |
| - attr = {"axes": [0], "ends": [INT64_MAX], "starts": [-1]} |
| 346 | + attr = {"axes": [0], "ends": [utils.get_max_value(np.int64)], "starts": [-1]} |
347 | 347 | inputs_map = {"data": indices_shape.output[0], **attr}
|
348 | 348 | inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
|
349 | 349 | outter_shape = ctx.make_node("Div",
|
@@ -401,7 +401,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt
|
401 | 401 | [inner_loop_shape.output[0], one_const.output[0]],
|
402 | 402 | attr={"axis": 0},
|
403 | 403 | dtypes=[TensorProto.INT64])
|
404 |
| - attr = {"axes": [0], "ends": [INT64_MAX], "starts": [1]} |
| 404 | + attr = {"axes": [0], "ends": [utils.get_max_value(np.int64)], "starts": [1]} |
405 | 405 | inputs_map = {"data": inner_loop_shape_.output[0], **attr}
|
406 | 406 | output_inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64])
|
407 | 407 | attr = {"axes": [0], "ends": [-1], "starts": [0]}
|
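
A side note on the two hunks above: they lean on `ends=[INT64_MAX]` as an "until the end" sentinel for ONNX Slice, which clamps out-of-range `ends` to the dimension size. Below is a minimal numpy sketch of the same semantics, under the assumption that `utils.get_max_value(np.int64)` simply returns `np.iinfo(np.int64).max` (the module constant it replaces):

```python
# Illustrative only: mimics the ONNX Slice attrs used above with numpy.
# Assumption: utils.get_max_value(np.int64) == np.iinfo(np.int64).max.
import numpy as np

INT64_MAX = np.iinfo(np.int64).max

shape = np.array([2, 3, 4], dtype=np.int64)  # a 1-D shape tensor, e.g. from a Shape node

# starts=[-1], ends=[INT64_MAX] on axis 0: everything from the last element on,
# i.e. the innermost dimension of the shape.
inner = shape[-1:INT64_MAX]   # -> array([4])

# starts=[1], ends=[INT64_MAX]: everything but the leading element.
tail = shape[1:INT64_MAX]     # -> array([3, 4])
```
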
@@ -932,3 +932,31 @@ def version_10(cls, ctx, node, **kwargs):
932 | 932 |         utils.make_sure(node_dtype, "Dtype of {} is None".format(node.name))
933 | 933 |         if node_dtype not in [onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.DOUBLE]:
934 | 934 |             raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")
| 935 | +
| 936 | +
| 937 | +@tf_op(["NonMaxSuppressionV2", "NonMaxSuppressionV3"], onnx_op="NonMaxSuppression")
| 938 | +class NonMaxSuppression:
| 939 | +    @classmethod
| 940 | +    def version_10(cls, ctx, node, **kwargs):
| 941 | +        # int32 = NonMaxSuppressionV2(T boxes, T scores, int32 max_output_size, T iou_threshold, T score_threshold)
| 942 | +        # int64 = NonMaxSuppression(T boxes, T scores, int64 max_output_size, T iou_threshold, T score_threshold),
| 943 | +        # where T means float32; the last 3 params are optional
| 944 | +        # tf boxes is 2D ([boxes_num, 4]) while onnx is 3D ([num_batches, boxes_num, 4])
| 945 | +        # tf scores is 1D ([boxes_num]) while onnx is 3D ([num_batches, num_classes, boxes_num])
| 946 | +        # onnx output is [num_selected_boxes, 3]; the meaning of the last dim is [batch_index, class_index, box_index],
| 947 | +        # while tf's output is [num_selected_boxes]
| 948 | +        ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[0], axes=[0])
| 949 | +        ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[1], axes=[0, 1])
| 950 | +        ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnx_pb.TensorProto.INT64)
| 951 | +        # replace the original node with nonmaxsuppress + slice + squeeze + cast
| 952 | +        dtypes = [ctx.get_dtype(node.output[0])]
| 953 | +        shapes = [ctx.get_shape(node.output[0])]
| 954 | +        ctx.remove_node(node.name)
| 955 | +        new_nonmaxsuppress = ctx.make_node(node.type, node.input).output[0]
| 956 | +        attr = {"axes": [1], "ends": [3], "starts": [2]}
| 957 | +        inputs_map = {"data": new_nonmaxsuppress, **attr}
| 958 | +        slice_op = GraphBuilder(ctx).make_slice(inputs_map)
| 959 | +        squeeze_op = ctx.make_node("Squeeze", [slice_op], attr={"axes": [1]})
| 960 | +        ctx.make_node("Cast", inputs=squeeze_op.output, attr={"to": onnx_pb.TensorProto.INT32},
| 961 | +                      name=node.name, outputs=node.output, dtypes=dtypes, shapes=shapes)
| 962 | +
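
For the NonMaxSuppression conversion added above, here is a hedged numpy sketch of why slicing out column 2 recovers TF's 1-D output (the index values are made up for illustration):

```python
# Illustrative only: numpy mock of the Slice -> Squeeze -> Cast chain above.
# ONNX NonMaxSuppression returns selected indices as [num_selected_boxes, 3]
# rows of (batch_index, class_index, box_index).
import numpy as np

selected = np.array([[0, 0, 5],
                     [0, 0, 2],
                     [0, 0, 7]], dtype=np.int64)  # hypothetical ONNX output

box_col = selected[:, 2:3]            # Slice: axes=[1], starts=[2], ends=[3]
box_idx = box_col.squeeze(axis=1)     # Squeeze: axes=[1]
tf_like = box_idx.astype(np.int32)    # Cast: TF NonMaxSuppression yields int32

assert tf_like.tolist() == [5, 2, 7]
```

Because the converter unsqueezes the TF inputs into a single batch and a single class, `batch_index` and `class_index` are always 0 in every row, so dropping those two columns loses nothing.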