Commit b4224a5

Implement conversion of CTCGreedyDecoder (#1530)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent b106603 commit b4224a5

2 files changed: +122 -0 lines changed

tests/test_backend.py

Lines changed: 26 additions & 0 deletions
@@ -4085,6 +4085,32 @@ def func(x):
             return tf.identity(y, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

+    @check_opset_min_version(11, "Range")
+    def test_ctc_greedy_decoder(self):
+        x_val = np.random.uniform(size=(3, 4, 5)).astype(np.float32)
+        s_val = np.array([3, 3, 2, 3], np.int32)
+        def func(x, s):
+            [decoded], logits = tf.nn.ctc_greedy_decoder(x, s, merge_repeated=False)
+            r1 = tf.identity(decoded.indices, name=_TFOUTPUT)
+            r2 = tf.identity(decoded.values, name=_TFOUTPUT1)
+            r3 = tf.identity(decoded.dense_shape, name=_TFOUTPUT2)
+            r4 = tf.identity(logits, name=_TFOUTPUT3)
+            return r1, r2, r3, r4
+        self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2, _OUTPUT3], {_INPUT: x_val, _INPUT1: s_val})
+
+    @check_opset_min_version(11, "Range")
+    def test_ctc_greedy_decoder_merge_repeated(self):
+        x_val = np.random.uniform(size=(6, 4, 5)).astype(np.float32)
+        s_val = np.array([5, 6, 4, 6], np.int32)
+        def func(x, s):
+            [decoded], logits = tf.nn.ctc_greedy_decoder(x, s, merge_repeated=True)
+            r1 = tf.identity(decoded.indices, name=_TFOUTPUT)
+            r2 = tf.identity(decoded.values, name=_TFOUTPUT1)
+            r3 = tf.identity(decoded.dense_shape, name=_TFOUTPUT2)
+            r4 = tf.identity(logits, name=_TFOUTPUT3)
+            return r1, r2, r3, r4
+        self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2, _OUTPUT3], {_INPUT: x_val, _INPUT1: s_val})
+
     # test for gemm pattern0: alpha*A*B + beta*C
     def test_gemm_pattern0(self):
         max_number = 10
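For context, the four tensors checked by these tests are the outputs of tf.nn.ctc_greedy_decoder: a list of SparseTensor results (indices, values, dense_shape) plus the negated sums of the chosen logits. A minimal standalone call with the same shapes as the first test might look like the following (illustrative snippet, not part of the commit):

import numpy as np
import tensorflow as tf

# logits: [max_time, batch_size, num_classes]; the last class is the CTC blank
logits = np.random.uniform(size=(3, 4, 5)).astype(np.float32)
seq_len = np.array([3, 3, 2, 3], np.int32)

[decoded], neg_sum_logits = tf.nn.ctc_greedy_decoder(logits, seq_len, merge_repeated=False)
print(decoded.indices)      # [num_decoded, 2] sparse (batch, time) coordinates
print(decoded.values)       # [num_decoded] decoded class ids
print(decoded.dense_shape)  # [batch_size, max_decoded_length]
print(neg_sum_logits)       # [batch_size, 1] negated sum of the selected logits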

tf2onnx/onnx_opset/nn.py

Lines changed: 96 additions & 0 deletions
@@ -1859,3 +1859,99 @@ def version_9(cls, ctx, node, **kwargs):
         label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])

         _make_sparse_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
+
+
+@tf_op("CTCGreedyDecoder")
+class CTCGreedyDecoder:
+    @classmethod
+    def version_11(cls, ctx, node, **kwargs):
+        # shape = [max_time, batch_size, num_classes]
+        inp = node.input[0]
+        # shape = [batch_size]
+        seq_lens = node.input[1]
+        seq_lens_int64 = ctx.make_node("Cast", [seq_lens], attr={"to": TensorProto.INT64}).output[0]
+        # shape = [1, batch_size, 1]
+        seq_lens_unsq = GraphBuilder(ctx).make_unsqueeze({"data": seq_lens_int64, "axes": [0, 2]})
+
+        merge_repeated = node.get_attr_value("merge_repeated", False)
+
+        inp_shape = ctx.make_node("Shape", [inp]).output[0]
+        max_time_unsq, num_batch_unsq, num_classes_unsq = ctx.make_node("Split", [inp_shape], output_count=3).output
+        max_time = GraphBuilder(ctx).make_squeeze({"data": max_time_unsq, "axes": [0]})
+        num_batch = GraphBuilder(ctx).make_squeeze({"data": num_batch_unsq, "axes": [0]})
+        num_classes = GraphBuilder(ctx).make_squeeze({"data": num_classes_unsq, "axes": [0]})
+        const_one = ctx.make_const(utils.make_name("const_one"), np.array(1, np.int64)).output[0]
+        const_one_unsq = ctx.make_const(utils.make_name("const_one"), np.array([1], np.int64)).output[0]
+        const_zero = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.int64)).output[0]
+        blank_label = ctx.make_node("Sub", [num_classes, const_one]).output[0]
+        time = ctx.make_node("Range", [const_zero, max_time, const_one]).output[0]
+        batch = ctx.make_node("Range", [const_zero, num_batch, const_one]).output[0]
+        # shape = [max_time, 1, 1]
+        time_unsq = GraphBuilder(ctx).make_unsqueeze({"data": time, "axes": [1, 2]})
+        valid_elts = ctx.make_node("Less", [time_unsq, seq_lens_unsq]).output[0]
+        # shape = [max_time, batch_size, 1]
+        valid_mask = ctx.make_node("Cast", [valid_elts], attr={"to": TensorProto.FLOAT}).output[0]
+        # shape = [max_time, batch_size, num_classes]
+        valid_inp = ctx.make_node("Mul", [inp, valid_mask]).output[0]
+
+        # shape = [max_time, batch_size, 1]
+        max_val, max_idx = ctx.make_node("TopK", [valid_inp, const_one_unsq], attr={"axis": 2},
+                                         output_count=2, op_name_scope=node.name).output
+        # shape = [batch_size, 1]
+        sum_max = GraphBuilder(ctx).make_reduce_sum({"data": max_val, "axes": [0], "keepdims": False})
+        sum_max_neg = ctx.make_node("Neg", [sum_max]).output[0]
+
+        valid_elts_sq = GraphBuilder(ctx).make_squeeze({"data": valid_elts, "axes": [2]})
+        max_idx_sq = GraphBuilder(ctx).make_squeeze({"data": max_idx, "axes": [2]})
+        # shape = [batch_size, max_time]
+        max_idx_trans = ctx.make_node("Transpose", [max_idx_sq], attr={"perm": [1, 0]}).output[0]
+        valid_elts_trans = ctx.make_node("Transpose", [valid_elts_sq], attr={"perm": [1, 0]}).output[0]
+
+        # value = [batch_size, max_time]
+        idx_shape = ctx.make_node("Shape", [max_idx_trans]).output[0]
+        keep_idx = ctx.make_node("Less", [max_idx_trans, blank_label]).output[0]
+        keep_idx = ctx.make_node("And", [keep_idx, valid_elts_trans]).output[0]
+
+        if merge_repeated:
+            # val = [batch_size, 1]
+            shift_row_shape = ctx.make_node("Concat", [num_batch_unsq, const_one_unsq], attr={'axis': 0}).output[0]
+            neg_one_tensor = helper.make_tensor("value", onnx_pb.TensorProto.INT64, dims=[1], vals=[-1])
+            # shape = [batch_size, 1]
+            neg_ones = ctx.make_node("ConstantOfShape", [shift_row_shape], {'value': neg_one_tensor}).output[0]
+            max_idx_cut = GraphBuilder(ctx).make_slice(
+                {"data": max_idx_trans, "starts": [0], "ends": [-1], "axes": [1]})
+            # shape = [batch_size, max_time]
+            max_idx_shift = ctx.make_node("Concat", [neg_ones, max_idx_cut], attr={"axis": 1}).output[0]
+            repeat_elts = ctx.make_node("Equal", [max_idx_shift, max_idx_trans]).output[0]
+            not_repeat = ctx.make_node("Not", [repeat_elts]).output[0]
+            keep_idx = ctx.make_node("And", [keep_idx, not_repeat]).output[0]
+
+        batch_unsq = GraphBuilder(ctx).make_unsqueeze({"data": batch, "axes": [1]})
+        batch_expand = ctx.make_node("Expand", [batch_unsq, idx_shape]).output[0]
+        keep_idx_int = ctx.make_node("Cast", [keep_idx], attr={"to": TensorProto.INT64}).output[0]
+        filtered_time = ctx.make_node("CumSum", [keep_idx_int, const_one], attr={"exclusive": True}).output[0]
+
+        flat_shape = ctx.make_const(utils.make_name("const_neg_one"), np.array([-1], np.int64)).output[0]
+        flat_shape2 = ctx.make_const(utils.make_name("const_shape"), np.array([-1, 1], np.int64)).output[0]
+        idx_flat = ctx.make_node("Reshape", [max_idx_trans, flat_shape]).output[0]
+        keep_idx_flat = ctx.make_node("Reshape", [keep_idx, flat_shape]).output[0]
+        time_flat = ctx.make_node("Reshape", [filtered_time, flat_shape2]).output[0]
+        batch_flat = ctx.make_node("Reshape", [batch_expand, flat_shape2]).output[0]
+        sparse_idx = ctx.make_node("Concat", [batch_flat, time_flat], attr={'axis': 1}).output[0]
+        idx_compress = ctx.make_node("Compress", [idx_flat, keep_idx_flat], attr={'axis': 0}, shapes=[[-1]],
+                                     op_name_scope=node.name).output[0]
+        sparse_idx_compress = ctx.make_node("Compress", [sparse_idx, keep_idx_flat], attr={'axis': 0}, shapes=[[-1, 2]],
+                                            op_name_scope=node.name).output[0]
+        max_sparse_idx = ctx.make_node("ReduceMax", [sparse_idx_compress],
+                                       attr={'axes': [0], 'keepdims': False}).output[0]
+        max_time = GraphBuilder(ctx).make_slice(
+            {"data": max_sparse_idx, "starts": [1], "ends": [2], "axes": [0]})
+        max_time_inc = ctx.make_node("Add", [max_time, const_one]).output[0]
+        sparse_shape = ctx.make_node("Concat", [num_batch_unsq, max_time_inc], attr={'axis': 0}).output[0]
+
+        ctx.replace_all_inputs(node.output[0], sparse_idx_compress)
+        ctx.replace_all_inputs(node.output[1], idx_compress)
+        ctx.replace_all_inputs(node.output[2], sparse_shape)
+        ctx.replace_all_inputs(node.output[3], sum_max_neg)
+
+        ctx.remove_node(node.name)
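The graph built above follows the usual greedy CTC decode: take the argmax class at each valid timestep, drop the blank label (the last class), optionally drop consecutive repeats when merge_repeated is set, and emit the result in TensorFlow's sparse format along with the negated sum of the selected logits. A rough NumPy sketch of that computation, for reference only (ctc_greedy_decode_ref is a hypothetical helper, not part of tf2onnx, and its edge-case behaviour is not guaranteed to match exactly):

import numpy as np

def ctc_greedy_decode_ref(inp, seq_lens, merge_repeated=False):
    # inp: [max_time, batch_size, num_classes]; seq_lens: [batch_size]
    max_time, batch_size, num_classes = inp.shape
    blank = num_classes - 1
    indices, values = [], []
    neg_sum_logits = np.zeros((batch_size, 1), np.float32)
    longest = 0
    for b in range(batch_size):
        out_t = 0
        prev = -1  # sentinel so the first timestep is never treated as a repeat
        for t in range(min(int(seq_lens[b]), max_time)):
            cls = int(np.argmax(inp[t, b]))
            neg_sum_logits[b, 0] -= inp[t, b, cls]
            keep = cls != blank and not (merge_repeated and cls == prev)
            prev = cls
            if keep:
                indices.append([b, out_t])  # sparse (batch, time) coordinate
                values.append(cls)
                out_t += 1
        longest = max(longest, out_t)
    dense_shape = np.array([batch_size, longest], np.int64)
    return (np.array(indices, np.int64), np.array(values, np.int64),
            dense_shape, neg_sum_logits)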
