Skip to content

Commit 47753a9

Browse files
Merge pull request #7527 from wanghaoshuang/ctc_greedy_decode
Add CTC align op
2 parents 6720e67 + 6089b50 commit 47753a9

File tree

5 files changed

+338
-5
lines changed

5 files changed

+338
-5
lines changed

paddle/operators/ctc_align_op.cc

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/ctc_align_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class CTCAlignOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
void InferShape(framework::InferShapeContext* ctx) const override {
25+
PADDLE_ENFORCE(ctx->HasInput("Input"),
26+
"Input of CTCAlignOp should not be null.");
27+
PADDLE_ENFORCE(ctx->HasOutput("Output"),
28+
"Output of CTCAlignOp should not be null.");
29+
30+
auto input_dims = ctx->GetInputDim("Input");
31+
32+
// TODO(wanghaoshuang): it is tricky to set the wrong dimension here.
33+
ctx->SetOutputDim("Output", input_dims);
34+
}
35+
36+
protected:
37+
framework::OpKernelType GetExpectedKernelType(
38+
const framework::ExecutionContext& ctx) const override {
39+
return framework::OpKernelType(
40+
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
41+
ctx.device_context());
42+
}
43+
};
44+
45+
// Declares the op's inputs, outputs, attributes and user-facing docs.
// Fixes doc-string typos: "LodTensor" -> "LoDTensor", "setted" -> "set".
class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CTCAlignOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Input",
             "(LoDTensor, default: LoDTensor<int>), Its shape is "
             "[Lp, 1], where Lp is the sum of all input sequences' length.");
    AddOutput("Output", "(Tensor, default: Tensor<int>), The align result.");
    AddAttr<int>("blank",
                 "(int, default: 0), the blank label set in Connectionist "
                 "Temporal Classification (CTC) op.")
        .SetDefault(0);
    AddAttr<bool>("merge_repeated",
                  "(bool, default: true), whether to "
                  "merge repeated elements between two blanks. ")
        .SetDefault(true);
    AddComment(R"DOC(
CTCAlign op is used to merge repeated elements between two blanks
and then delete all blanks in sequence.

Given:
    Input.data = [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6,
                  6, 0, 0, 7, 7, 7, 0]
    Input.dims = {18, 1}
    Input.LoD = [[0, 11, 18]]

And:
    blank = 0
    merge_repeated = True

Then:
    Output.data = [1, 2, 4, 4, 5, 6,
                   6, 7]
    Output.dims = {8, 1}
    Output.LoD = [[0, 6, 8]]

)DOC");
  }
};
84+
85+
} // namespace operators
86+
} // namespace paddle
87+
88+
namespace ops = paddle::operators;
// ctc_align is a post-processing op with no gradient, hence the empty
// grad-op maker.
REGISTER_OPERATOR(ctc_align, ops::CTCAlignOp, ops::CTCAlignOpMaker,
                  paddle::framework::EmptyGradOpMaker);
// CPU kernels for the integer label types.
REGISTER_OP_CPU_KERNEL(
    ctc_align, ops::CTCAlignKernel<paddle::platform::CPUDeviceContext, int>,
    ops::CTCAlignKernel<paddle::platform::CPUDeviceContext, int64_t>);

paddle/operators/ctc_align_op.cu

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <stdio.h>
16+
#include <thrust/device_vector.h>
17+
#include <thrust/host_vector.h>
18+
#include "paddle/operators/ctc_align_op.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
template <typename T>
24+
__global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens,
25+
const size_t num_seq, size_t* lod0,
26+
const int blank, const int merge_repeated,
27+
size_t* out_lod0, T* output) {
28+
int ouput_idx = 0;
29+
out_lod0[0] = 0;
30+
31+
for (int i = 0; i < num_seq; ++i) {
32+
T pre_token = -1;
33+
for (int j = lod0[i]; j < lod0[i + 1]; ++j) {
34+
if (tokens[j] != blank && !(merge_repeated && tokens[j] == pre_token)) {
35+
output[ouput_idx] = tokens[j];
36+
++ouput_idx;
37+
}
38+
pre_token = tokens[j];
39+
}
40+
out_lod0[i + 1] = ouput_idx;
41+
}
42+
}
43+
44+
// GPU kernel wrapper: delegates the merge/delete work to a single-thread
// CUDA kernel, then rebuilds the output LoD on the host and shrinks the
// output tensor to the number of kept tokens.
template <typename T>
class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace.");
    const size_t level = 0;
    auto* input = ctx.Input<LoDTensor>("Input");
    auto* output = ctx.Output<LoDTensor>("Output");
    auto input_lod = framework::ToAbsOffset(input->lod());

    const T* tokens = input->data<T>();
    const int64_t num_tokens = input->dims()[0];
    const size_t num_seq = input_lod[level].size() - 1;

    const int blank = ctx.Attr<int>("blank");
    const int merge_repeated =
        static_cast<int>(ctx.Attr<bool>("merge_repeated"));

    // prepare a lod to record lod information while merging elements
    // (device-side buffer the kernel writes the per-sequence offsets into)
    thrust::device_vector<size_t> dev_out_lod0(input_lod[level].size());
    size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data());

    // merge elements and delete blank
    // Worst-case allocation (no token removed); shrunk via Resize() below.
    T* output_data = output->mutable_data<T>({num_tokens, 1}, ctx.GetPlace());

    auto stream = ctx.cuda_device_context().stream();
    // NOTE(review): input_lod[level].data() appears to be a host-side
    // pointer handed to a __global__ kernel — confirm this LoD Vector type
    // exposes device-accessible memory here; otherwise the kernel would
    // dereference an invalid address on the GPU.
    MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
        num_tokens, tokens, num_seq, input_lod[level].data(), blank,
        merge_repeated, dev_out_lod0_ptr, output_data);

    // set output lod
    // Copy the offsets back to the host (implicitly synchronizes with the
    // kernel via thrust) and build the output LoD from them.
    thrust::host_vector<size_t> host_out_lod0(dev_out_lod0.begin(),
                                              dev_out_lod0.end());
    framework::LoD out_lod;
    out_lod.push_back(host_out_lod0);
    output->set_lod(out_lod);

    // resize output dims
    output->Resize({static_cast<int64_t>(host_out_lod0.back()), 1});
  }
};
86+
87+
} // namespace operators
88+
} // namespace paddle
89+
90+
// GPU kernels for the integer label types.
REGISTER_OP_CUDA_KERNEL(ctc_align, paddle::operators::CTCAlignOpCUDAKernel<int>,
                        paddle::operators::CTCAlignOpCUDAKernel<int64_t>);

paddle/operators/ctc_align_op.h

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <string.h>
18+
#include "paddle/framework/op_registry.h"
19+
namespace paddle {
20+
namespace operators {
21+
22+
using Tensor = framework::Tensor;
23+
using LoDTensor = framework::LoDTensor;
24+
25+
template <typename DeviceContext, typename T>
26+
class CTCAlignKernel : public framework::OpKernel<T> {
27+
public:
28+
void Compute(const framework::ExecutionContext& ctx) const override {
29+
auto* input = ctx.Input<LoDTensor>("Input");
30+
auto* output = ctx.Output<LoDTensor>("Output");
31+
const size_t level = 0;
32+
auto input_lod = framework::ToAbsOffset(input->lod());
33+
34+
// check input dims and lod
35+
auto input_dims = input->dims();
36+
PADDLE_ENFORCE_EQ(input_dims[0],
37+
static_cast<int64_t>(input_lod[level].back()),
38+
"The first dimension of Input(Input) should be equal to "
39+
"the sum of all sequences' lengths.");
40+
41+
const size_t num_sequences = input_lod[level].size() - 1;
42+
size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));
43+
bool merge_repeated = ctx.Attr<bool>("merge_repeated");
44+
45+
// merge repeated tokens and delete blank
46+
T* output_data = output->mutable_data<T>(ctx.GetPlace());
47+
size_t output_idx = 0;
48+
std::vector<size_t> output_lod0(1, 0);
49+
const T* input_data = input->data<T>();
50+
for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) {
51+
T prev_token = -1;
52+
for (size_t i = input_lod[level][seq_idx];
53+
i < input_lod[level][seq_idx + 1]; ++i) {
54+
if (input_data[i] != blank &&
55+
!(merge_repeated && input_data[i] == prev_token)) {
56+
output_data[output_idx] = input_data[i];
57+
++output_idx;
58+
}
59+
prev_token = input_data[i];
60+
}
61+
output_lod0.push_back(output_idx);
62+
}
63+
64+
// set output lod
65+
framework::LoD output_lod;
66+
output_lod.push_back(output_lod0);
67+
output->set_lod(output_lod);
68+
69+
// resize output dims
70+
output->Resize({static_cast<int64_t>(output_lod0.back()), 1});
71+
}
72+
};
73+
74+
} // namespace operators
75+
} // namespace paddle

paddle/operators/sequence_expand_op.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ This operator expands input(X) according to LOD of input(Y).
5858
Following are cases to better explain how this works:
5959
Case 1:
6060
61-
Given 2-level a LoDTensor input(X)
61+
Given a 2-level LoDTensor input(X)
6262
X.lod = [[0, 2, 3],
6363
[0, 1, 3, 4]]
6464
X.data = [a, b, c, d]
@@ -75,9 +75,8 @@ then we get 2-level LoDTensor
7575
7676
Case 2:
7777
78-
Given a 0-level LoDTensor input(X)
78+
Given a common Tensor input(X)
7979
X.data = [a, b, c]
80-
X.lod = NULL
8180
X.dims = [3, 1]
8281
and input(Y)
8382
Y.lod = [[0, 2, 3, 6]]
@@ -89,9 +88,8 @@ then we get 1-level LoDTensor
8988
9089
Case 3:
9190
92-
Given a 0-level LoDTensor input(X)
91+
Given a common Tensor input(X)
9392
X.data = [[a, b], [c, d], [e, f]]
94-
X.lod = NULL
9593
X.dims = [3, 2]
9694
and input(Y)
9795
Y.lod = [[0, 2, 3, 6]]
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
#Licensed under the Apache License, Version 2.0 (the "License");
4+
#you may not use this file except in compliance with the License.
5+
#You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
#Unless required by applicable law or agreed to in writing, software
10+
#distributed under the License is distributed on an "AS IS" BASIS,
11+
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
#See the License for the specific language governing permissions and
13+
#limitations under the License.
14+
15+
import sys
16+
import unittest
17+
import numpy as np
18+
from op_test import OpTest
19+
from test_softmax_op import stable_softmax
20+
21+
22+
def CTCAlign(input, lod, blank, merge_repeated):
    """Pure-Python reference for the ctc_align op.

    Per sequence, drops `blank` labels and, when `merge_repeated` is true,
    first collapses runs of identical consecutive labels.

    Args:
        input: int array of shape [Lp, 1] with all sequences' labels.
        lod: level of detail; lod[0] holds absolute sequence offsets.
        blank: label id treated as the CTC blank.
        merge_repeated: whether consecutive duplicate labels are merged.

    Returns:
        int32 numpy array of shape [num_kept, 1].
    """
    offsets = lod[0]
    kept = []
    for start, end in zip(offsets[:-1], offsets[1:]):
        last = -1  # sentinel; never equals a valid label
        for pos in range(start, end):
            cur = input[pos][0]
            if cur != blank and not (merge_repeated and cur == last):
                kept.append(cur)
            last = cur
    return np.array(kept).reshape([len(kept), 1]).astype("int32")
35+
36+
37+
class TestCTCAlignOp(OpTest):
    """Checks ctc_align (merge_repeated=False) on a two-sequence batch.

    Fix: removed a dead `pass` statement left after self.check_output().
    """

    def config(self):
        # Base configuration; subclasses override this to cover other cases.
        self.op_type = "ctc_align"
        self.input_lod = [[0, 11, 18]]
        self.blank = 0
        self.merge_repeated = False
        self.input = np.array(
            [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]).reshape(
                [18, 1]).astype("int32")

    def setUp(self):
        self.config()
        # Expected output comes from the pure-Python reference implementation.
        output = CTCAlign(self.input, self.input_lod, self.blank,
                          self.merge_repeated)

        self.inputs = {"Input": (self.input, self.input_lod), }
        self.outputs = {"Output": output}
        self.attrs = {
            "blank": self.blank,
            "merge_repeated": self.merge_repeated
        }

    def test_check_output(self):
        self.check_output()
62+
63+
64+
class TestCTCAlignOpCase1(TestCTCAlignOp):
    # Variant: merge_repeated=True with trailing blanks, so repeated labels
    # are collapsed and blanks at the end of a sequence are removed.
    def config(self):
        self.op_type = "ctc_align"
        self.input_lod = [[0, 11, 19]]
        self.blank = 0
        self.merge_repeated = True
        self.input = np.array(
            [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0, 0]).reshape(
                [19, 1]).astype("int32")
73+
74+
75+
# Run all test cases when the file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)