@@ -1859,3 +1859,99 @@ def version_9(cls, ctx, node, **kwargs):
1859
1859
label_node = ctx .make_node ("Cast" , label_node .output , attr = {"to" : logit_dtype }, dtypes = [logit_dtype ])
1860
1860
1861
1861
_make_sparse_softmax_cross_entropy_with_logits (ctx , label_node , logit_node , node )
1862
+
1863
+
1864
@tf_op("CTCGreedyDecoder")
class CTCGreedyDecoder:
    """Handler that rewrites a TF ``CTCGreedyDecoder`` node into ONNX ops."""

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """Replace ``node`` with an equivalent subgraph and remove it.

        ``node.input[0]`` is the [max_time, batch_size, num_classes] input and
        ``node.input[1]`` the [batch_size] sequence lengths.  The class with
        index ``num_classes - 1`` is treated as the blank label.  Steps at or
        beyond a row's sequence length are masked out.  When the
        ``merge_repeated`` attribute is true, a step whose best class equals
        the best class of the previous step is dropped.

        Consumers of the four outputs of ``node`` are rewired to:
          0. [N, 2] (batch, position) indices of the kept labels,
          1. [N] kept label values,
          2. 2-element dense shape [batch_size, longest decoded row],
          3. negated sum over time of the best (masked) input value.
        """
        # shape = [max_time, batch_size, num_classes]
        inp = node.input[0]
        # shape = [batch_size]
        seq_lens = node.input[1]
        seq_lens_int64 = ctx.make_node("Cast", [seq_lens], attr={"to": TensorProto.INT64}).output[0]
        # shape = [1, batch_size, 1]
        seq_lens_unsq = GraphBuilder(ctx).make_unsqueeze({"data": seq_lens_int64, "axes": [0, 2]})

        merge_repeated = node.get_attr_value("merge_repeated", False)

        inp_shape = ctx.make_node("Shape", [inp]).output[0]
        max_time_unsq, num_batch_unsq, num_classes_unsq = ctx.make_node("Split", [inp_shape], output_count=3).output
        max_time = GraphBuilder(ctx).make_squeeze({"data": max_time_unsq, "axes": [0]})
        num_batch = GraphBuilder(ctx).make_squeeze({"data": num_batch_unsq, "axes": [0]})
        num_classes = GraphBuilder(ctx).make_squeeze({"data": num_classes_unsq, "axes": [0]})
        const_one = ctx.make_const(utils.make_name("const_one"), np.array(1, np.int64)).output[0]
        const_one_unsq = ctx.make_const(utils.make_name("const_one"), np.array([1], np.int64)).output[0]
        const_zero = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.int64)).output[0]
        # The last class index is the blank label.
        blank_label = ctx.make_node("Sub", [num_classes, const_one]).output[0]
        time = ctx.make_node("Range", [const_zero, max_time, const_one]).output[0]
        batch = ctx.make_node("Range", [const_zero, num_batch, const_one]).output[0]
        # shape = [max_time, 1, 1]
        time_unsq = GraphBuilder(ctx).make_unsqueeze({"data": time, "axes": [1, 2]})
        # True where the time step lies inside the row's sequence length.
        valid_elts = ctx.make_node("Less", [time_unsq, seq_lens_unsq]).output[0]
        # shape = [max_time, batch_size, 1]
        # NOTE(review): the mask is always cast to FLOAT; presumably the input is float32 - confirm.
        valid_mask = ctx.make_node("Cast", [valid_elts], attr={"to": TensorProto.FLOAT}).output[0]
        # shape = [max_time, batch_size, num_classes]
        valid_inp = ctx.make_node("Mul", [inp, valid_mask]).output[0]

        # shape = [max_time, batch_size, 1]
        max_val, max_idx = ctx.make_node("TopK", [valid_inp, const_one_unsq], attr={"axis": 2},
                                         output_count=2, op_name_scope=node.name).output
        # shape = [batch_size, 1]
        sum_max = GraphBuilder(ctx).make_reduce_sum({"data": max_val, "axes": [0], "keepdims": False})
        sum_max_neg = ctx.make_node("Neg", [sum_max]).output[0]

        valid_elts_sq = GraphBuilder(ctx).make_squeeze({"data": valid_elts, "axes": [2]})
        max_idx_sq = GraphBuilder(ctx).make_squeeze({"data": max_idx, "axes": [2]})
        # shape = [batch_size, max_time]
        max_idx_trans = ctx.make_node("Transpose", [max_idx_sq], attr={"perm": [1, 0]}).output[0]
        valid_elts_trans = ctx.make_node("Transpose", [valid_elts_sq], attr={"perm": [1, 0]}).output[0]

        # value = [batch_size, max_time]
        idx_shape = ctx.make_node("Shape", [max_idx_trans]).output[0]
        # Keep a step only if its best class is not the blank label and the step is valid.
        keep_idx = ctx.make_node("Less", [max_idx_trans, blank_label]).output[0]
        keep_idx = ctx.make_node("And", [keep_idx, valid_elts_trans]).output[0]

        if merge_repeated:
            # val = [batch_size, 1]
            shift_row_shape = ctx.make_node("Concat", [num_batch_unsq, const_one_unsq], attr={'axis': 0}).output[0]
            neg_one_tensor = helper.make_tensor("value", onnx_pb.TensorProto.INT64, dims=[1], vals=[-1])
            # shape = [batch_size, 1]
            neg_ones = ctx.make_node("ConstantOfShape", [shift_row_shape], {'value': neg_one_tensor}).output[0]
            max_idx_cut = GraphBuilder(ctx).make_slice(
                {"data": max_idx_trans, "starts": [0], "ends": [-1], "axes": [1]})
            # shape = [batch_size, max_time]
            # Best classes shifted right by one step; -1 fills the first column so step 0 never matches.
            max_idx_shift = ctx.make_node("Concat", [neg_ones, max_idx_cut], attr={"axis": 1}).output[0]
            repeat_elts = ctx.make_node("Equal", [max_idx_shift, max_idx_trans]).output[0]
            not_repeat = ctx.make_node("Not", [repeat_elts]).output[0]
            keep_idx = ctx.make_node("And", [keep_idx, not_repeat]).output[0]

        batch_unsq = GraphBuilder(ctx).make_unsqueeze({"data": batch, "axes": [1]})
        batch_expand = ctx.make_node("Expand", [batch_unsq, idx_shape]).output[0]
        keep_idx_int = ctx.make_node("Cast", [keep_idx], attr={"to": TensorProto.INT64}).output[0]
        # Exclusive running count of kept steps along time = output position of each kept step.
        filtered_time = ctx.make_node("CumSum", [keep_idx_int, const_one], attr={"exclusive": True}).output[0]

        flat_shape = ctx.make_const(utils.make_name("const_neg_one"), np.array([-1], np.int64)).output[0]
        flat_shape2 = ctx.make_const(utils.make_name("const_shape"), np.array([-1, 1], np.int64)).output[0]
        idx_flat = ctx.make_node("Reshape", [max_idx_trans, flat_shape]).output[0]
        keep_idx_flat = ctx.make_node("Reshape", [keep_idx, flat_shape]).output[0]
        time_flat = ctx.make_node("Reshape", [filtered_time, flat_shape2]).output[0]
        batch_flat = ctx.make_node("Reshape", [batch_expand, flat_shape2]).output[0]
        sparse_idx = ctx.make_node("Concat", [batch_flat, time_flat], attr={'axis': 1}).output[0]
        idx_compress = ctx.make_node("Compress", [idx_flat, keep_idx_flat], attr={'axis': 0}, shapes=[[-1]],
                                     op_name_scope=node.name).output[0]
        sparse_idx_compress = ctx.make_node("Compress", [sparse_idx, keep_idx_flat], attr={'axis': 0},
                                            shapes=[[-1, 2]], op_name_scope=node.name).output[0]

        # Dense shape = [batch_size, longest decoded row].  The longest row is
        # taken from per-row counts of kept steps rather than from the maximum
        # of the compressed indices: both agree whenever at least one step is
        # kept, but the count is still well defined (0) when nothing is kept
        # and the compressed index tensor is empty.
        # shape = [batch_size]
        row_lens = GraphBuilder(ctx).make_reduce_sum({"data": keep_idx_int, "axes": [1], "keepdims": False})
        # shape = [1]
        longest_row = ctx.make_node("ReduceMax", [row_lens], attr={'axes': [0], 'keepdims': True}).output[0]
        sparse_shape = ctx.make_node("Concat", [num_batch_unsq, longest_row], attr={'axis': 0}).output[0]

        ctx.replace_all_inputs(node.output[0], sparse_idx_compress)
        ctx.replace_all_inputs(node.output[1], idx_compress)
        ctx.replace_all_inputs(node.output[2], sparse_shape)
        ctx.replace_all_inputs(node.output[3], sum_max_neg)

        ctx.remove_node(node.name)
0 commit comments