@@ -1761,6 +1761,144 @@ def version_13(cls, ctx, node, **kwargs):
        cls.any_version(13, ctx, node, **kwargs)
+@tf_op(["CombinedNonMaxSuppression"])
+class CombinedNonMaxSuppression:
+    @classmethod
+    def version_12(cls, ctx, node, **kwargs):
+        # boxes.shape = [batch_size, num_boxes, (1 OR num_classes), 4]
+        # scores.shape = [batch_size, num_boxes, num_classes]
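+        # Strategy: ONNX NonMaxSuppression already handles batched, multi-class
+        # inputs, so run it once, then scatter the ragged selections into a dense
+        # [batch_size, max_per_class * num_classes] grid and take a per-batch TopK
+        # to reproduce TF's combined, padded output format.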
+        boxes, scores, max_per_class, max_total_size, iou_threshold, score_threshold = node.input
+
+        max_per_class = ctx.make_node("Cast", [max_per_class], attr={'to': TensorProto.INT64}).output[0]
+        max_total_size = ctx.make_node("Cast", [max_total_size], attr={'to': TensorProto.INT64}).output[0]
+
+        pad_per_class = node.get_attr_value("pad_per_class", False)
+        clip_boxes = node.get_attr_value("clip_boxes", True)
+        shape = ctx.get_shape(boxes)
+        share_boxes_across_classes = shape is not None and shape[2] == 1
+        utils.make_sure(share_boxes_across_classes,
+                        "CombinedNonMaxSuppression only currently implemented for boxes shared across classes.")
+
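+        # Derive batch_size, num_classes, and the per-batch result width
+        # (max_per_class * num_classes) dynamically from the shape of scores.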
+        scores_shape = ctx.make_node("Shape", [scores]).output[0]
+        # value: [batch_size]
+        batch_size = GraphBuilder(ctx).make_slice({'data': scores_shape, 'starts': [0], 'ends': [1], 'axes': [0]})
+
+        num_classes = GraphBuilder(ctx).make_slice({'data': scores_shape, 'starts': [2], 'ends': [3], 'axes': [0]})
+        max_per_class_times_classes = ctx.make_node("Mul", [max_per_class, num_classes]).output[0]
+
+        const_zero_float = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.float32)).output[0]
+        const_one_float = ctx.make_const(utils.make_name("const_one"), np.array(1, np.float32)).output[0]
+        const_zero = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.int64)).output[0]
+        const_neg_one = ctx.make_const(utils.make_name("const_neg_one"), np.array(-1, np.int64)).output[0]
+        const_one = ctx.make_const(utils.make_name("const_one"), np.array(1, np.int64)).output[0]
+
+        boxes_sq = GraphBuilder(ctx).make_squeeze({'data': boxes, 'axes': [2]})
+        # scores_trans.shape = [batch_size, num_classes, num_boxes]
+        scores_trans = ctx.make_node("Transpose", [scores], attr={'perm': [0, 2, 1]}).output[0]
+        # shape: [num_selected, 3], elts of format [batch_index, class_index, box_index]
+        selected_indices = ctx.make_node(
+            "NonMaxSuppression", [boxes_sq, scores_trans, max_per_class, iou_threshold, score_threshold],
+            op_name_scope=node.name).output[0]
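+        # Split selected_indices into its batch/class/box columns, then gather
+        # the score and box for each selection.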
+        selected_classes_unsq = GraphBuilder(ctx).make_slice(
+            {'data': selected_indices, 'starts': [1], 'ends': [2], 'axes': [1]})
+        selected_classes = GraphBuilder(ctx).make_squeeze({'data': selected_classes_unsq, 'axes': [1]})
+        # shape: [num_selected]
+        selected_scores = ctx.make_node("GatherND", [scores_trans, selected_indices], op_name_scope=node.name).output[0]
+        # shape: [num_selected, 1]
+        selected_batch_idx = GraphBuilder(ctx).make_slice(
+            {'data': selected_indices, 'starts': [0], 'ends': [1], 'axes': [1]})
+        selected_box_num = GraphBuilder(ctx).make_slice(
+            {'data': selected_indices, 'starts': [2], 'ends': [3], 'axes': [1]})
+        combined_box_idx = ctx.make_node("Concat", [selected_batch_idx, selected_box_num], attr={'axis': 1}).output[0]
+        selected_boxes_unsq = ctx.make_node("GatherND", [boxes, combined_box_idx], op_name_scope=node.name).output[0]
+        # shape: [num_selected, 4]
+        selected_boxes = GraphBuilder(ctx).make_squeeze({'data': selected_boxes_unsq, 'axes': [1]})
+
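+        # clip_boxes=True clamps the selected coordinates to the [0, 1] unit
+        # square, matching TF's behavior.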
+        clipped_boxes = selected_boxes
+        if clip_boxes:
+            clipped_boxes = ctx.make_node('Max', [clipped_boxes, const_zero_float]).output[0]
+            clipped_boxes = ctx.make_node('Min', [clipped_boxes, const_one_float]).output[0]
+
+        # shape: [num_selected]
+        batch_idx_sq = GraphBuilder(ctx).make_squeeze({'data': selected_batch_idx, 'axes': [1]})
+        # value: [num_selected]
+        num_selected = ctx.make_node("Shape", [selected_scores]).output[0]
+        num_selected_sq = GraphBuilder(ctx).make_squeeze({'data': num_selected, 'axes': [0]})
+        # shape: [num_selected]
+        selected_range = ctx.make_node("Range", [const_zero, num_selected_sq, const_one]).output[0]
+
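+        # Compute each selection's position within its own batch: one-hot encode
+        # the batch index of every selection, then an exclusive CumSum down the
+        # selection axis counts, for each row, how many earlier selections fall
+        # in the same batch. GatherND reads that running count back per row.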
+        id_shape = ctx.make_node("Concat", [batch_size, batch_size], attr={'axis': 0}).output[0]
+        zero_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[0])
+        zeros_of_shape = ctx.make_node("ConstantOfShape", [id_shape], attr={"value": zero_tensor}).output[0]
+        # shape: [batch_size, batch_size]
+        id_matrix = ctx.make_node("EyeLike", [zeros_of_shape]).output[0]
+        # shape: [num_selected, batch_size]
+        one_hot_batch_idx = ctx.make_node("Gather", [id_matrix, batch_idx_sq], attr={'axis': 0}).output[0]
+        cum_batch_idx = ctx.make_node("CumSum", [one_hot_batch_idx, const_zero], attr={'exclusive': True}).output[0]
+        # shape: [num_selected]
+        idx_within_batch = ctx.make_node("GatherND", [cum_batch_idx, selected_batch_idx], attr={'batch_dims': 1},
+                                         op_name_scope=node.name).output[0]
+        idx_within_batch_unsq = GraphBuilder(ctx).make_unsqueeze({'data': idx_within_batch, 'axes': [1]})
+        combined_idx = ctx.make_node("Concat", [selected_batch_idx, idx_within_batch_unsq], attr={'axis': 1}).output[0]
+
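+        # Scatter the ragged selections into dense per-batch grids. Scores
+        # default to 0 and indices default to -1 so empty slots are
+        # distinguishable from real selections.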
+        zero_tensor_float = helper.make_tensor("value", TensorProto.FLOAT, dims=[1], vals=[0])
+        neg_one_tensor_int64 = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[-1])
+        # value: [batch_size, max_per_class_times_classes]
+        results_grid_shape = ctx.make_node(
+            "Concat", [batch_size, max_per_class_times_classes], attr={'axis': 0}).output[0]
+        scores_by_batch_empty = ctx.make_node(
+            "ConstantOfShape", [results_grid_shape], attr={"value": zero_tensor_float}).output[0]
+        idx_by_batch_empty = ctx.make_node(
+            "ConstantOfShape", [results_grid_shape], attr={"value": neg_one_tensor_int64}).output[0]
+
+        scores_by_batch = ctx.make_node("ScatterND", [scores_by_batch_empty, combined_idx, selected_scores]).output[0]
+        idx_by_batch = ctx.make_node("ScatterND", [idx_by_batch_empty, combined_idx, selected_range]).output[0]
+
+        k_val = ctx.make_node("Min", [max_total_size, max_per_class_times_classes]).output[0]
+
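+        # A per-batch TopK over the dense score grid picks the best k results
+        # across all classes at once.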
+        # shape: [batch_size, k_val]
+        top_k_vals, top_k_indices = \
+            ctx.make_node("TopK", [scores_by_batch, k_val], attr={'axis': 1}, output_count=2).output
+
+        top_k_selected_indices = ctx.make_node("GatherElements", [idx_by_batch, top_k_indices], attr={'axis': 1},
+                                               op_name_scope=node.name).output[0]
+
+        target_size = max_total_size
+        if pad_per_class:
+            target_size = k_val
+
+        pad_amt = ctx.make_node("Sub", [target_size, k_val]).output[0]
+        pads_const = ctx.make_const(utils.make_name("pad_const"), np.array([0, 0, 0], np.int64)).output[0]
+        pads = ctx.make_node("Concat", [pads_const, pad_amt], attr={'axis': 0}).output[0]
+
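+        # Pad the results out to target_size, then shift all indices up by one
+        # so the -1 padding slots map to index 0, which points at the zero row
+        # prepended to the boxes/classes tensors below.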
+        top_scores_pad = ctx.make_node("Pad", [top_k_vals, pads, const_zero_float]).output[0]
+        top_indices_pad = ctx.make_node("Pad", [top_k_selected_indices, pads, const_neg_one]).output[0]
+        top_indices_increment = ctx.make_node("Add", [top_indices_pad, const_one]).output[0]
+
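+        # Count the non-padding entries per batch for the valid_detections output.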
+        valid_indices = ctx.make_node("Greater", [top_k_selected_indices, const_neg_one]).output[0]
+        valid_indices_int = ctx.make_node("Cast", [valid_indices], attr={'to': TensorProto.INT32}).output[0]
+        # shape: [batch_size]
+        valid_indices_cnt = GraphBuilder(ctx).make_reduce_sum(
+            {"data": valid_indices_int, "axes": [-1], "keepdims": 0, "noop_with_empty_axes": 1})
+
+        box_pads = ctx.make_const(utils.make_name("pad_const"), np.array([1, 0, 0, 0], np.int64)).output[0]
+        class_pads = ctx.make_const(utils.make_name("pad_const"), np.array([1, 0], np.int64)).output[0]
+        clipped_boxes_pad = ctx.make_node("Pad", [clipped_boxes, box_pads, const_zero_float]).output[0]
+        selected_classes_pad = ctx.make_node("Pad", [selected_classes, class_pads, const_zero]).output[0]
+        nmsed_boxes = ctx.make_node("Gather", [clipped_boxes_pad, top_indices_increment], attr={'axis': 0},
+                                    op_name_scope=node.name).output[0]
+        nmsed_classes = ctx.make_node("Gather", [selected_classes_pad, top_indices_increment], attr={'axis': 0},
+                                      op_name_scope=node.name).output[0]
+        nmsed_classes_float = ctx.make_node("Cast", [nmsed_classes], attr={'to': TensorProto.FLOAT}).output[0]
+
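+        # Wire up the four TF outputs: nmsed_boxes, nmsed_scores, nmsed_classes,
+        # and valid_detections.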
+        ctx.replace_all_inputs(node.output[0], nmsed_boxes)
+        ctx.replace_all_inputs(node.output[1], top_scores_pad)
+        ctx.replace_all_inputs(node.output[2], nmsed_classes_float)
+        ctx.replace_all_inputs(node.output[3], valid_indices_cnt)
+        ctx.remove_node(node.name)
+
+
@tf_op ("ReverseSequence" )
class ReverseSequence:
    @classmethod