Skip to content

Commit 7150289

Browse files
committed
Refine CPU kernel
1. Allocate memory for the output before computing it. 2. Rename 'ctc_decode' to 'ctc_align'.
1 parent adcfde3 commit 7150289

File tree

4 files changed

+28
-33
lines changed

4 files changed

+28
-33
lines changed

paddle/operators/ctc_decode_op.cc renamed to paddle/operators/ctc_align_op.cc

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/operators/ctc_decode_op.h"
15+
#include "paddle/operators/ctc_align_op.h"
1616

1717
namespace paddle {
1818
namespace operators {
1919

20-
class CTCDecodeOp : public framework::OperatorWithKernel {
20+
class CTCAlignOp : public framework::OperatorWithKernel {
2121
public:
2222
using framework::OperatorWithKernel::OperatorWithKernel;
2323

2424
void InferShape(framework::InferShapeContext* ctx) const override {
2525
PADDLE_ENFORCE(ctx->HasInput("Input"),
26-
"Input of CTCDecodeOp should not be null.");
26+
"Input of CTCAlignOp should not be null.");
2727
PADDLE_ENFORCE(ctx->HasOutput("Output"),
28-
"Output of CTCDecodeOp should not be null.");
28+
"Output of CTCAlignOp should not be null.");
2929

3030
auto input_dims = ctx->GetInputDim("Input");
3131

@@ -42,14 +42,14 @@ class CTCDecodeOp : public framework::OperatorWithKernel {
4242
}
4343
};
4444

45-
class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
45+
class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
4646
public:
47-
CTCDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
47+
CTCAlignOpMaker(OpProto* proto, OpAttrChecker* op_checker)
4848
: OpProtoAndCheckerMaker(proto, op_checker) {
4949
AddInput("Input",
5050
"(LodTensor, default: LoDTensor<int>), Its shape is "
5151
"[Lp, 1], where Lp is the sum of all input sequences' length.");
52-
AddOutput("Output", "(Tensor, default: Tensor<int>), The decode result.");
52+
AddOutput("Output", "(Tensor, default: Tensor<int>), The align result.");
5353
AddAttr<int>("blank",
5454
"(int, default: 0), the blank label setted in Connectionist "
5555
"Temporal Classification (CTC) op.")
@@ -59,7 +59,7 @@ class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
5959
"merge repeated elements between two blanks. ")
6060
.SetDefault(true);
6161
AddComment(R"DOC(
62-
CTCDecoder is used to merge repeated elements between two blanks
62+
CTCAlign op is used to merge repeated elements between two blanks
6363
and then delete all blanks in sequence.
6464
6565
Given:
@@ -86,7 +86,7 @@ and then delete all blanks in sequence.
8686
} // namespace paddle
8787

8888
namespace ops = paddle::operators;
89-
REGISTER_OPERATOR(ctc_decode, ops::CTCDecodeOp, ops::CTCDecodeOpMaker,
89+
REGISTER_OPERATOR(ctc_align, ops::CTCAlignOp, ops::CTCAlignOpMaker,
9090
paddle::framework::EmptyGradOpMaker);
9191
REGISTER_OP_CPU_KERNEL(
92-
ctc_decode, ops::CTCDecodeKernel<paddle::platform::CPUDeviceContext, int>);
92+
ctc_align, ops::CTCAlignKernel<paddle::platform::CPUDeviceContext, int>);

paddle/operators/ctc_decode_op.cu renamed to paddle/operators/ctc_align_op.cu

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ limitations under the License. */
1515
#include <stdio.h>
1616
#include <thrust/device_vector.h>
1717
#include <thrust/host_vector.h>
18-
#include "paddle/operators/ctc_decode_op.h"
18+
#include "paddle/operators/ctc_align_op.h"
1919

2020
namespace paddle {
2121
namespace operators {
@@ -42,7 +42,7 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens,
4242
}
4343

4444
template <typename T>
45-
class CTCDecodeOpCUDAKernel : public framework::OpKernel<T> {
45+
class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
4646
public:
4747
void Compute(const framework::ExecutionContext& ctx) const override {
4848
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
@@ -87,5 +87,5 @@ class CTCDecodeOpCUDAKernel : public framework::OpKernel<T> {
8787
} // namespace operators
8888
} // namespace paddle
8989

90-
REGISTER_OP_CUDA_KERNEL(ctc_decode,
91-
paddle::operators::CTCDecodeOpCUDAKernel<int>);
90+
REGISTER_OP_CUDA_KERNEL(ctc_align,
91+
paddle::operators::CTCAlignOpCUDAKernel<int>);

paddle/operators/ctc_decode_op.h renamed to paddle/operators/ctc_align_op.h

Lines changed: 7 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ using Tensor = framework::Tensor;
2323
using LoDTensor = framework::LoDTensor;
2424

2525
template <typename DeviceContext, typename T>
26-
class CTCDecodeKernel : public framework::OpKernel<T> {
26+
class CTCAlignKernel : public framework::OpKernel<T> {
2727
public:
2828
void Compute(const framework::ExecutionContext& ctx) const override {
2929
auto* input = ctx.Input<LoDTensor>("Input");
@@ -43,7 +43,8 @@ class CTCDecodeKernel : public framework::OpKernel<T> {
4343
bool merge_repeated = ctx.Attr<bool>("merge_repeated");
4444

4545
// merge repeated tokens and delete blank
46-
std::vector<std::vector<int>> pathes(num_sequences);
46+
T* output_data = output->mutable_data<T>(ctx.GetPlace());
47+
size_t output_idx = 0;
4748
std::vector<size_t> output_lod0(1, 0);
4849
const T* input_data = input->data<T>();
4950
for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) {
@@ -52,11 +53,12 @@ class CTCDecodeKernel : public framework::OpKernel<T> {
5253
i < input_lod[level][seq_idx + 1]; ++i) {
5354
if (input_data[i] != blank &&
5455
!(merge_repeated && input_data[i] == prev_token)) {
55-
pathes[seq_idx].push_back(input_data[i]);
56+
output_data[output_idx] = input_data[i];
57+
++output_idx;
5658
}
5759
prev_token = input_data[i];
5860
}
59-
output_lod0.push_back(output_lod0.back() + pathes[seq_idx].size());
61+
output_lod0.push_back(output_idx);
6062
}
6163

6264
// set output lod
@@ -65,14 +67,7 @@ class CTCDecodeKernel : public framework::OpKernel<T> {
6567
output->set_lod(output_lod);
6668

6769
// resize output dims
68-
T* output_data = output->mutable_data<T>(
69-
{static_cast<int64_t>(output_lod0.back()), 1}, ctx.GetPlace());
70-
71-
// copy result to output
72-
for (int i = 0; i < num_sequences; ++i) {
73-
memcpy(output_data + output_lod0[i], pathes[i].data(),
74-
sizeof(int) * pathes[i].size());
75-
}
70+
output->Resize({static_cast<int64_t>(output_lod0.back()), 1});
7671
}
7772
};
7873

python/paddle/v2/fluid/tests/test_ctc_decode.py renamed to python/paddle/v2/fluid/tests/test_ctc_align.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from test_softmax_op import stable_softmax
66

77

8-
def CTCDecode(input, lod, blank, merge_repeated):
8+
def CTCAlign(input, lod, blank, merge_repeated):
99
lod0 = lod[0]
1010
result = []
1111
for i in range(len(lod0) - 1):
@@ -20,9 +20,9 @@ def CTCDecode(input, lod, blank, merge_repeated):
2020
return result
2121

2222

23-
class TestCTCDecodeOp(OpTest):
23+
class TestCTCAlignOp(OpTest):
2424
def config(self):
25-
self.op_type = "ctc_decode"
25+
self.op_type = "ctc_align"
2626
self.input_lod = [[0, 11, 18]]
2727
self.blank = 0
2828
self.merge_repeated = False
@@ -32,8 +32,8 @@ def config(self):
3232

3333
def setUp(self):
3434
self.config()
35-
output = CTCDecode(self.input, self.input_lod, self.blank,
36-
self.merge_repeated)
35+
output = CTCAlign(self.input, self.input_lod, self.blank,
36+
self.merge_repeated)
3737

3838
self.inputs = {"Input": (self.input, self.input_lod), }
3939
self.outputs = {"Output": output}
@@ -47,9 +47,9 @@ def test_check_output(self):
4747
pass
4848

4949

50-
class TestCTCDecodeOpCase1(TestCTCDecodeOp):
50+
class TestCTCAlignOpCase1(TestCTCAlignOp):
5151
def config(self):
52-
self.op_type = "ctc_decode"
52+
self.op_type = "ctc_align"
5353
self.input_lod = [[0, 11, 19]]
5454
self.blank = 0
5555
self.merge_repeated = True

0 commit comments

Comments
 (0)