Skip to content

Commit 4c29b91

Browse files
Tom/combined non max suppression (#1376)
* Implement conversion for one case of CombinedNonMaxSuppression (boxes shared across classes)
  Signed-off-by: Tom Wildenhain <[email protected]>
* Update tests min version
  Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 4a7428a commit 4c29b91

File tree

2 files changed

+186
-0
lines changed

2 files changed

+186
-0
lines changed

tests/test_backend.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
_OUTPUT1 = "output1:0"
5252
_TFOUTPUT2 = "output2"
5353
_OUTPUT2 = "output2:0"
54+
# Name given to the fourth graph output via tf.identity(..., name=_TFOUTPUT3) in tests with four outputs.
_TFOUTPUT3 = "output3"
55+
# Tensor name of the fourth output ("output3" plus the ":0" suffix), listed among the outputs given to _run_test_case.
_OUTPUT3 = "output3:0"
5456

5557

5658
if is_tf2():
@@ -3530,6 +3532,52 @@ def func(boxes, scores):
35303532

35313533
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: boxes_val, _INPUT1: scores_val})
35323534

3535+
@check_tf_min_version("2.3")
@check_opset_min_version(12, "GatherND with batch_dims")
def test_combined_non_max_suppression_pad_and_clip(self):
    # Exercises tf.image.combined_non_max_suppression with pad_per_class=True and clip_boxes=True.
    num_batches, num_boxes, num_classes = 8, 10, 2
    total_size_cap = 9
    # Boxes are shared across classes (third dim is 1); coordinates lie in [-0.5, 1.5),
    # so they may fall outside the [0, 1] range.
    boxes_val = np.random.random_sample([num_batches, num_boxes, 1, 4]).astype(np.float32) * 2 - 0.5
    scores_val = np.random.random_sample([num_batches, num_boxes, num_classes]).astype(np.float32)

    def func(boxes, scores):
        picked_boxes, picked_scores, picked_classes, num_valid = tf.image.combined_non_max_suppression(
            boxes=boxes,
            scores=scores,
            score_threshold=0.1,
            max_output_size_per_class=3,
            max_total_size=total_size_cap,
            iou_threshold=0.5,
            pad_per_class=True,
            clip_boxes=True)
        # Name each of the four results so they can be looked up as graph outputs.
        return (
            tf.identity(picked_boxes, name=_TFOUTPUT),
            tf.identity(picked_scores, name=_TFOUTPUT1),
            tf.identity(picked_classes, name=_TFOUTPUT2),
            tf.identity(num_valid, name=_TFOUTPUT3),
        )

    expected_outputs = [_OUTPUT, _OUTPUT1, _OUTPUT2, _OUTPUT3]
    feed = {_INPUT: boxes_val, _INPUT1: scores_val}
    self._run_test_case(func, expected_outputs, feed)
3557+
3558+
@check_tf_min_version("2.3")
@check_opset_min_version(12, "GatherND with batch_dims")
def test_combined_non_max_suppression_no_pad_no_clip(self):
    # Exercises tf.image.combined_non_max_suppression with pad_per_class=False and clip_boxes=False.
    num_batches, num_boxes, num_classes = 8, 10, 2
    total_size_cap = 9
    # Boxes are shared across classes (third dim is 1); coordinates lie in [-0.5, 1.5),
    # so they may fall outside the [0, 1] range.
    boxes_val = np.random.random_sample([num_batches, num_boxes, 1, 4]).astype(np.float32) * 2 - 0.5
    scores_val = np.random.random_sample([num_batches, num_boxes, num_classes]).astype(np.float32)

    def func(boxes, scores):
        picked_boxes, picked_scores, picked_classes, num_valid = tf.image.combined_non_max_suppression(
            boxes=boxes,
            scores=scores,
            score_threshold=0.1,
            max_output_size_per_class=3,
            max_total_size=total_size_cap,
            iou_threshold=0.5,
            pad_per_class=False,
            clip_boxes=False)
        # Name each of the four results so they can be looked up as graph outputs.
        return (
            tf.identity(picked_boxes, name=_TFOUTPUT),
            tf.identity(picked_scores, name=_TFOUTPUT1),
            tf.identity(picked_classes, name=_TFOUTPUT2),
            tf.identity(num_valid, name=_TFOUTPUT3),
        )

    expected_outputs = [_OUTPUT, _OUTPUT1, _OUTPUT2, _OUTPUT3]
    feed = {_INPUT: boxes_val, _INPUT1: scores_val}
    self._run_test_case(func, expected_outputs, feed)
3580+
35333581
def _conv1d_test(self, x_val, w, stride=None, padding="VALID", rtol=1e-07):
35343582
if stride is None:
35353583
stride = 1

tf2onnx/onnx_opset/tensor.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,6 +1761,144 @@ def version_13(cls, ctx, node, **kwargs):
17611761
cls.any_version(13, ctx, node, **kwargs)
17621762

17631763

1764+
@tf_op(["CombinedNonMaxSuppression"])
class CombinedNonMaxSuppression:
    """Rewrites a TF CombinedNonMaxSuppression node as a subgraph built around ONNX NonMaxSuppression.

    Only the case where boxes are shared across classes (boxes.shape[2] == 1) is implemented.
    """

    @classmethod
    def version_12(cls, ctx, node, **kwargs):
        """Replace `node` with an equivalent ONNX subgraph and rewire its four outputs.

        The handler is registered for opset 12 because the subgraph uses GatherND with the
        `batch_dims` attribute (added in opset 12), alongside ScatterND, CumSum, Range,
        GatherElements and input-based Pad (all opset 11). Registering it for opset 10 would
        emit operators that do not exist at that opset.

        Args:
            ctx: graph under conversion; used to create nodes/constants and to rewire outputs.
            node: the CombinedNonMaxSuppression node; its six inputs are boxes, scores,
                max_output_size_per_class, max_total_size, iou_threshold and score_threshold.
            **kwargs: unused.

        The outputs nmsed_boxes, nmsed_scores, nmsed_classes and valid_detections are replaced
        in that order, after which `node` is removed from the graph.
        utils.make_sure is called with a false condition when the boxes shape is unknown or
        its third dimension is not 1 (presumably raising -- confirm against utils.make_sure).
        """
        # boxes.shape = [batch_size, num_boxes, (1 OR num_classes), 4]
        # scores.shape = [batch_size, num_boxes, num_classes]
        boxes, scores, max_per_class, max_total_size, iou_threshold, score_threshold = node.input

        max_per_class = ctx.make_node("Cast", [max_per_class], attr={'to': TensorProto.INT64}).output[0]
        max_total_size = ctx.make_node("Cast", [max_total_size], attr={'to': TensorProto.INT64}).output[0]

        pad_per_class = node.get_attr_value("pad_per_class", False)
        clip_boxes = node.get_attr_value("clip_boxes", True)
        shape = ctx.get_shape(boxes)
        share_boxes_across_classes = shape is not None and shape[2] == 1
        utils.make_sure(share_boxes_across_classes,
                        "CombinedNonMaxSuppression only currently implemented for boxes shared across classes.")

        scores_shape = ctx.make_node("Shape", [scores]).output[0]
        # value: [batch_size]
        batch_size = GraphBuilder(ctx).make_slice({'data': scores_shape, 'starts': [0], 'ends': [1], 'axes': [0]})
        # value: [num_classes]
        num_classes = GraphBuilder(ctx).make_slice({'data': scores_shape, 'starts': [2], 'ends': [3], 'axes': [0]})
        # Upper bound on the number of boxes NonMaxSuppression can select per batch element.
        max_per_class_times_classes = ctx.make_node("Mul", [max_per_class, num_classes]).output[0]

        const_zero_float = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.float32)).output[0]
        const_one_float = ctx.make_const(utils.make_name("const_one"), np.array(1, np.float32)).output[0]
        const_zero = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.int64)).output[0]
        const_neg_one = ctx.make_const(utils.make_name("const_neg_one"), np.array(-1, np.int64)).output[0]
        const_one = ctx.make_const(utils.make_name("const_one"), np.array(1, np.int64)).output[0]

        # boxes_sq.shape = [batch_size, num_boxes, 4]
        boxes_sq = GraphBuilder(ctx).make_squeeze({'data': boxes, 'axes': [2]})
        # scores_trans.shape = [batch_size, num_classes, num_boxes]
        scores_trans = ctx.make_node("Transpose", [scores], attr={'perm': [0, 2, 1]}).output[0]
        # shape: [num_selected, 3], elts of format [batch_index, class_index, box_index]
        selected_indices = ctx.make_node(
            "NonMaxSuppression", [boxes_sq, scores_trans, max_per_class, iou_threshold, score_threshold],
            op_name_scope=node.name).output[0]
        selected_classes_unsq = GraphBuilder(ctx).make_slice(
            {'data': selected_indices, 'starts': [1], 'ends': [2], 'axes': [1]})
        # shape: [num_selected]
        selected_classes = GraphBuilder(ctx).make_squeeze({'data': selected_classes_unsq, 'axes': [1]})
        # shape: [num_selected]
        selected_scores = ctx.make_node("GatherND", [scores_trans, selected_indices], op_name_scope=node.name).output[0]
        # shape: [num_selected, 1]
        selected_batch_idx = GraphBuilder(ctx).make_slice(
            {'data': selected_indices, 'starts': [0], 'ends': [1], 'axes': [1]})
        # shape: [num_selected, 1]
        selected_box_num = GraphBuilder(ctx).make_slice(
            {'data': selected_indices, 'starts': [2], 'ends': [3], 'axes': [1]})
        # shape: [num_selected, 2], elts of format [batch_index, box_index]
        combined_box_idx = ctx.make_node("Concat", [selected_batch_idx, selected_box_num], attr={'axis': 1}).output[0]
        # Gathering from the unsqueezed boxes keeps the size-1 class dim: [num_selected, 1, 4]
        selected_boxes_unsq = ctx.make_node("GatherND", [boxes, combined_box_idx], op_name_scope=node.name).output[0]
        # shape: [num_selected, 4]
        selected_boxes = GraphBuilder(ctx).make_squeeze({'data': selected_boxes_unsq, 'axes': [1]})

        clipped_boxes = selected_boxes
        if clip_boxes:
            # Clamp every coordinate into [0, 1].
            clipped_boxes = ctx.make_node('Max', [clipped_boxes, const_zero_float]).output[0]
            clipped_boxes = ctx.make_node('Min', [clipped_boxes, const_one_float]).output[0]

        # shape: [num_selected]
        batch_idx_sq = GraphBuilder(ctx).make_squeeze({'data': selected_batch_idx, 'axes': [1]})
        # value: [num_selected]
        num_selected = ctx.make_node("Shape", [selected_scores]).output[0]
        num_selected_sq = GraphBuilder(ctx).make_squeeze({'data': num_selected, 'axes': [0]})
        # shape: [num_selected], value: 0 .. num_selected - 1 (position of each selection in the flat list)
        selected_range = ctx.make_node("Range", [const_zero, num_selected_sq, const_one]).output[0]

        # Work out, for every selection, its position among the selections of the same batch element:
        # one-hot encode the batch index, take an exclusive running sum over the selections, then read
        # back the column belonging to the selection's own batch.
        id_shape = ctx.make_node("Concat", [batch_size, batch_size], attr={'axis': 0}).output[0]
        zero_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[0])
        zeros_of_shape = ctx.make_node("ConstantOfShape", [id_shape], attr={"value": zero_tensor}).output[0]
        # shape: [batch_size, batch_size]
        id_matrix = ctx.make_node("EyeLike", [zeros_of_shape]).output[0]
        # shape: [num_selected, batch_size]
        one_hot_batch_idx = ctx.make_node("Gather", [id_matrix, batch_idx_sq], attr={'axis': 0}).output[0]
        cum_batch_idx = ctx.make_node("CumSum", [one_hot_batch_idx, const_zero], {'exclusive': True}).output[0]
        # shape: [num_selected]
        idx_within_batch = ctx.make_node("GatherND", [cum_batch_idx, selected_batch_idx], attr={'batch_dims': 1},
                                         op_name_scope=node.name).output[0]
        idx_within_batch_unsq = GraphBuilder(ctx).make_unsqueeze({'data': idx_within_batch, 'axes': [1]})
        # shape: [num_selected, 2], elts of format [batch_index, index_within_batch]
        combined_idx = ctx.make_node("Concat", [selected_batch_idx, idx_within_batch_unsq], attr={'axis': 1}).output[0]

        zero_tensor_float = helper.make_tensor("value", TensorProto.FLOAT, dims=[1], vals=[0])
        # INT64 fill value: -1 marks grid slots that hold no selection.
        neg_one_tensor_int = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[-1])
        # value: [batch_size, max_per_class_times_classes]
        results_grid_shape = ctx.make_node(
            "Concat", [batch_size, max_per_class_times_classes], attr={'axis': 0}).output[0]
        scores_by_batch_empty = ctx.make_node(
            "ConstantOfShape", [results_grid_shape], attr={"value": zero_tensor_float}).output[0]
        idx_by_batch_empty = ctx.make_node(
            "ConstantOfShape", [results_grid_shape], attr={"value": neg_one_tensor_int}).output[0]

        # Scatter the flat selections into per-batch rows; unused slots keep score 0 / index -1.
        scores_by_batch = ctx.make_node("ScatterND", [scores_by_batch_empty, combined_idx, selected_scores]).output[0]
        idx_by_batch = ctx.make_node("ScatterND", [idx_by_batch_empty, combined_idx, selected_range]).output[0]

        k_val = ctx.make_node("Min", [max_total_size, max_per_class_times_classes]).output[0]

        # shape: [batch_size, k_val]
        top_k_vals, top_k_indices = \
            ctx.make_node("TopK", [scores_by_batch, k_val], attr={'axis': 1}, output_count=2).output

        # Map the top-k grid columns back to flat selection indices (-1 where the slot was empty).
        top_k_selected_indices = ctx.make_node("GatherElements", [idx_by_batch, top_k_indices], attr={'axis': 1},
                                               op_name_scope=node.name).output[0]

        # Without pad_per_class the results are padded up to max_total_size; with it they keep width k_val.
        target_size = max_total_size
        if pad_per_class:
            target_size = k_val

        pad_amt = ctx.make_node("Sub", [target_size, k_val]).output[0]
        pads_const = ctx.make_const(utils.make_name("pad_const"), np.array([0, 0, 0], np.int64)).output[0]
        # value: [0, 0, 0, pad_amt], i.e. pad only at the end of axis 1
        pads = ctx.make_node("Concat", [pads_const, pad_amt], attr={'axis': 0}).output[0]

        top_scores_pad = ctx.make_node("Pad", [top_k_vals, pads, const_zero_float]).output[0]
        top_indices_pad = ctx.make_node("Pad", [top_k_selected_indices, pads, const_neg_one]).output[0]
        # Shift indices by one so that -1 (no detection) becomes 0 and hits the zero row prepended below.
        top_indices_increment = ctx.make_node("Add", [top_indices_pad, const_one]).output[0]

        valid_indices = ctx.make_node("Greater", [top_k_selected_indices, const_neg_one]).output[0]
        valid_indices_int = ctx.make_node("Cast", [valid_indices], attr={'to': TensorProto.INT32}).output[0]
        # shape: [batch_size]
        valid_indices_cnt = GraphBuilder(ctx).make_reduce_sum(
            {"data": valid_indices_int, "axes": [-1], "keepdims": 0, "noop_with_empty_axes": 1})

        # Prepend one all-zero entry along axis 0 to serve as the value for "no detection" slots.
        box_pads = ctx.make_const(utils.make_name("pad_const"), np.array([1, 0, 0, 0], np.int64)).output[0]
        class_pads = ctx.make_const(utils.make_name("pad_const"), np.array([1, 0], np.int64)).output[0]
        clipped_boxes_pad = ctx.make_node("Pad", [clipped_boxes, box_pads, const_zero_float]).output[0]
        selected_classes_pad = ctx.make_node("Pad", [selected_classes, class_pads, const_zero]).output[0]
        nmsed_boxes = ctx.make_node("Gather", [clipped_boxes_pad, top_indices_increment], attr={'axis': 0},
                                    op_name_scope=node.name).output[0]
        nmsed_classes = ctx.make_node("Gather", [selected_classes_pad, top_indices_increment], attr={'axis': 0},
                                      op_name_scope=node.name).output[0]
        nmsed_classes_float = ctx.make_node("Cast", [nmsed_classes], attr={'to': TensorProto.FLOAT}).output[0]

        ctx.replace_all_inputs(node.output[0], nmsed_boxes)
        ctx.replace_all_inputs(node.output[1], top_scores_pad)
        ctx.replace_all_inputs(node.output[2], nmsed_classes_float)
        ctx.replace_all_inputs(node.output[3], valid_indices_cnt)
        ctx.remove_node(node.name)
1900+
1901+
17641902
@tf_op("ReverseSequence")
17651903
class ReverseSequence:
17661904
@classmethod

0 commit comments

Comments
 (0)