
Commit 579f684

Add ctc_greedy_decode_op
1 parent 8d253e4 commit 579f684

4 files changed: +364 −0 lines changed
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/ctc_greedy_decode_op.h"

namespace paddle {
namespace operators {

class CTCGreedyDecodeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input of CTCGreedyDecodeOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Output"),
                   "Output of CTCGreedyDecodeOp should not be null.");

    auto input_dims = ctx->GetInputDim("Input");

    int sequence_width =
        static_cast<int>(framework::product(input_dims) / input_dims[0]);
    int blank = ctx->Attrs().Get<int>("blank");
    PADDLE_ENFORCE((blank >= 0) && (blank < sequence_width),
                   "The value of Attr(blank) should be in interval [0, %d).",
                   sequence_width);
    // TODO(wanghaoshuang): The dimension set here is only a placeholder;
    // the true output length is not known until the kernel runs, so the
    // kernel resizes Output afterwards.
    ctx->SetOutputDim("Output", {input_dims[0], 1});
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
        ctx.device_context());
  }
};

class CTCGreedyDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CTCGreedyDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Input",
             "(LoDTensor, default: LoDTensor<float>), the unscaled "
             "probabilities of variable-length sequences; a 2-D Tensor "
             "with LoD information. Its shape is [Lp, num_classes + 1], "
             "where Lp is the sum of all input sequences' lengths and "
             "num_classes is the true number of classes (not including "
             "the blank label).");
    AddOutput("Output", "(Tensor, default: Tensor<int>), the decoded result.");
    AddAttr<int>("blank",
                 "(int, default: 0), the blank label used by the "
                 "Connectionist Temporal Classification (CTC) op; it lies "
                 "in the half-open interval [0, num_classes + 1).")
        .SetDefault(0);
    AddAttr<bool>("merge_repeated",
                  "(bool, default: true), whether to "
                  "merge repeated elements between two blanks.")
        .SetDefault(true);
    AddComment(R"DOC(
CTCGreedyDecoder is an implementation of the simple best-path decoding
algorithm, which selects the most likely class at each timestep.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(ctc_greedy_decode, ops::CTCGreedyDecodeOp,
                  ops::CTCGreedyDecodeOpMaker,
                  paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
    ctc_greedy_decode,
    ops::CTCGreedyDecodeKernel<paddle::platform::CPUDeviceContext, float>);
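
The DOC comment above describes best-path decoding only briefly. As a minimal
NumPy sketch of the rule the kernels implement (illustration only, not part of
this commit; the token sequence and blank index are made up):

import numpy as np

blank, merge_repeated = 0, True
# Hypothetical per-timestep argmax tokens for one sequence.
tokens = np.array([1, 1, 0, 1, 2, 2, 0])

decoded, prev = [], -1
for t in tokens:
    if t != blank and not (merge_repeated and t == prev):
        decoded.append(int(t))
    prev = t
print(decoded)  # [1, 1, 2]: the 1 repeated across the blank is kept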
Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdio.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/operators/ctc_greedy_decode_op.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/gpu_info.h"

namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;

// Atomic max for float, built from a compare-and-swap loop on the int
// reinterpretation of the value.
__device__ static float atomicMaxF(float* address, float val) {
  int* address_as_i = (int*)address;
  int old = *address_as_i, assumed;
  do {
    assumed = old;
    old = ::atomicCAS(address_as_i, assumed,
                      __float_as_int(::fmaxf(val, __int_as_float(assumed))));
  } while (assumed != old);
  return __int_as_float(old);
}

// One block per timestep: each thread scans a strided slice of the class
// scores, then the block-wide maximum is combined via atomicMaxF. Note that
// initializing the maxima to 0 assumes non-negative scores, which holds for
// the probabilities this op consumes.
template <typename T, int BlockSize>
__global__ void ArgmaxCudaKernel(const size_t seq_width, const T* logits,
                                 int* output) {
  T local_max_value = 0;
  int local_max_index = 0;
  __shared__ T max_value;
  if (threadIdx.x == 0) {
    max_value = 0;
  }
  __syncthreads();

  for (int i = threadIdx.x; i < seq_width; i += BlockSize) {
    T value = logits[blockIdx.x * seq_width + i];
    if (value > local_max_value) {
      local_max_value = value;
      local_max_index = i;
    }
  }

  atomicMaxF(&max_value, local_max_value);

  __syncthreads();

  if (local_max_value == max_value) {
    output[blockIdx.x] = local_max_index;
  }
}

// Sequentially drops blanks and (optionally) merges repeated tokens within
// each sequence, writing the compacted tokens and the output LoD offsets.
template <typename T>
__global__ void MergeAndDelCudaKernel(const int64_t num_token, int* tokens,
                                      const size_t num_seq, size_t* lod0,
                                      const int blank, const int merge_repeated,
                                      size_t* out_lod0, int* output) {
  int output_idx = 0;
  out_lod0[0] = 0;

  for (int i = 0; i < num_seq; ++i) {
    int pre_token = -1;
    for (int j = lod0[i]; j < lod0[i + 1]; ++j) {
      if (tokens[j] != blank && !(merge_repeated && tokens[j] == pre_token)) {
        output[output_idx] = tokens[j];
        ++output_idx;
      }
      pre_token = tokens[j];
    }
    out_lod0[i + 1] = output_idx;
  }
}

template <typename T>
class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace.");
    auto* input = ctx.Input<LoDTensor>("Input");
    auto* output = ctx.Output<LoDTensor>("Output");

    const int64_t num_tokens = input->dims()[0];
    const size_t seq_width = input->numel() / num_tokens;
    const T* logits = input->data<T>();
    Tensor tmp;
    int* tokens = tmp.mutable_data<int>({num_tokens, 1}, ctx.GetPlace());

    // Compute the argmax over classes for every timestep.
    auto stream = ctx.cuda_device_context().stream();
    ArgmaxCudaKernel<T, PADDLE_CUDA_NUM_THREADS><<<
        num_tokens, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(seq_width, logits,
                                                          tokens);

    const size_t level = 0;
    auto input_lod = framework::ToAbsOffset(input->lod());
    const size_t num_seq = input_lod[level].size() - 1;
    const int blank = ctx.Attr<int>("blank");
    const int merge_repeated =
        static_cast<int>(ctx.Attr<bool>("merge_repeated"));

    thrust::device_vector<size_t> dev_out_lod0(input_lod[level].size());
    size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data());

    int* output_data =
        output->mutable_data<int>({num_tokens, 1}, ctx.GetPlace());
    MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
        num_tokens, tokens, num_seq, input_lod[level].data(), blank,
        merge_repeated, dev_out_lod0_ptr, output_data);

    thrust::host_vector<size_t> host_out_lod0(dev_out_lod0.begin(),
                                              dev_out_lod0.end());
    framework::LoD out_lod;
    out_lod.push_back(host_out_lod0);
    output->set_lod(out_lod);

    output->Resize({static_cast<int64_t>(host_out_lod0.back()), 1});
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP_CUDA_KERNEL(ctc_greedy_decode,
                        paddle::operators::CTCGreedyDecodeOpCUDAKernel<float>);
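
A note on the design above: MergeAndDelCudaKernel is launched as <<<1, 1>>>
because compacting tokens and accumulating out_lod0 is inherently sequential
within and across sequences. A small Python sketch of the LoD bookkeeping,
with hypothetical offsets and kept-token counts (not part of this commit):

# Absolute offsets of 4 input sequences, as in the test below.
lod0 = [0, 4, 5, 8, 11]
kept = [2, 1, 3, 2]  # hypothetical tokens surviving blank removal / merging

out_lod0 = [0]
for n in kept:
    out_lod0.append(out_lod0[-1] + n)
print(out_lod0)  # [0, 2, 3, 6, 8] -- offsets of the decoded sequences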
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string.h>
#include "paddle/framework/op_registry.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

template <typename DeviceContext, typename T>
class CTCGreedyDecodeKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<LoDTensor>("Input");
    auto* output = ctx.Output<LoDTensor>("Output");
    const size_t level = 0;

    auto input_lod = framework::ToAbsOffset(input->lod());
    auto input_dims = input->dims();
    PADDLE_ENFORCE_EQ(input_dims[0],
                      static_cast<int64_t>(input_lod[level].back()),
                      "The first dimension of Input(Input) should be equal to "
                      "the sum of all sequences' lengths.");

    const size_t num_sequences = input_lod[level].size() - 1;
    const size_t sequence_width = input->numel() / input_dims[0];
    size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));
    bool merge_repeated = ctx.Attr<bool>("merge_repeated");
    std::vector<std::vector<int>> paths(num_sequences);
    std::vector<size_t> output_lod0(1, 0);

    // View the input as a [num_timesteps, sequence_width] row-major matrix,
    // so the per-timestep argmax becomes a row-wise reduction.
    const T* input_data = input->data<T>();
    Eigen::Map<
        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
        input_mat(const_cast<T*>(input_data), input->numel() / sequence_width,
                  sequence_width);

    size_t max_class_idx;
    for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) {
      // Reset the previous class at each sequence boundary so that repeats
      // are only merged within a sequence, matching the CUDA kernel.
      size_t prev_class_idx = static_cast<size_t>(-1);
      for (size_t i = input_lod[level][seq_idx];
           i < input_lod[level][seq_idx + 1]; ++i) {
        input_mat.row(i).maxCoeff(&max_class_idx);
        if (max_class_idx != blank &&
            !(merge_repeated && max_class_idx == prev_class_idx)) {
          paths[seq_idx].push_back(max_class_idx);
        }
        prev_class_idx = max_class_idx;
      }
      output_lod0.push_back(output_lod0.back() + paths[seq_idx].size());
    }
    framework::LoD output_lod;
    output_lod.push_back(output_lod0);
    output->set_lod(output_lod);
    int64_t num_step = static_cast<int64_t>(output_lod0.back());
    int* output_data = output->mutable_data<int>({num_step, 1}, ctx.GetPlace());

    for (size_t i = 0; i < num_sequences; ++i) {
      memcpy(output_data + output_lod0[i], paths[i].data(),
             sizeof(int) * paths[i].size());
    }
  }
};

}  // namespace operators
}  // namespace paddle
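
To make the LoD slicing used by the kernel above concrete, a Python sketch
with made-up sizes (not part of this commit): three sequences of lengths 4,
1, and 3 are stacked into one [Lp, num_classes + 1] matrix, and the LoD
offsets recover each sequence.

import numpy as np

num_classes = 5                  # hypothetical
lengths = [4, 1, 3]
lod0 = np.cumsum([0] + lengths)  # [0, 4, 5, 8] -- absolute offsets
Lp = lod0[-1]
probs = np.random.rand(Lp, num_classes + 1).astype("float32")

# The kernel reads rows lod0[i]:lod0[i + 1] as sequence i.
seq1 = probs[lod0[1]:lod0[2]]    # shape (1, 6)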
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
import sys
import unittest
import numpy as np
from op_test import OpTest
from test_softmax_op import stable_softmax


def CTCGreedyDecode(softmax, blank, merge_repeated, lod0):
    # Reference implementation: decode each sequence independently by taking
    # the argmax at every timestep, dropping blanks and (optionally) merging
    # repeated tokens within a sequence.
    result = []
    for i in range(len(lod0) - 1):
        prev_token = -1
        for token in np.argmax(softmax[lod0[i]:lod0[i + 1]], axis=1):
            if (token != blank) and not (merge_repeated and
                                         token == prev_token):
                result.append(token)
            prev_token = token
    return np.array(result).reshape([len(result), 1])


class TestCTCGreedyDecodeOp(OpTest):
    def config(self):
        self.op_type = "ctc_greedy_decode"
        self.batch_size = 4
        self.num_classes = 8
        self.input_lod = [[0, 4, 5, 8, 11]]
        self.blank = 7
        self.merge_repeated = True

    def setUp(self):
        self.config()
        input = np.random.uniform(
            0.1, 1.0,
            [self.input_lod[0][-1], self.num_classes]).astype("float32")
        softmax = np.apply_along_axis(stable_softmax, 1, input)
        output = CTCGreedyDecode(softmax, self.blank, self.merge_repeated,
                                 self.input_lod[0])

        self.inputs = {"Input": (softmax, self.input_lod), }
        self.outputs = {"Output": output}
        self.attrs = {
            "blank": self.blank,
            "merge_repeated": self.merge_repeated
        }

    def test_check_output(self):
        self.check_output()


class TestCTCGreedyDecodeOpCase1(TestCTCGreedyDecodeOp):
    def config(self):
        self.op_type = "ctc_greedy_decode"
        self.batch_size = 4
        self.num_classes = 1025
        self.input_lod = [[0, 4, 5, 8, 11]]
        self.blank = 0
        self.merge_repeated = True


if __name__ == "__main__":
    unittest.main()
