
Commit 281e93b

Remove 'top 1' from CPU and GPU kernel

1. Remove 'top 1' (argmax) from the CPU and GPU kernels.
2. Add a new test case.
3. Refine the doc.
1 parent 579f684 commit 281e93b

File tree

5 files changed (+118, -155 lines)


paddle/operators/ctc_greedy_decode_op.cc renamed to paddle/operators/ctc_decode_op.cc

Lines changed: 25 additions & 19 deletions
@@ -29,14 +29,8 @@ class CTCGreedyDecodeOp : public framework::OperatorWithKernel {
 
     auto input_dims = ctx->GetInputDim("Input");
 
-    int sequence_width =
-        static_cast<int>(framework::product(input_dims) / input_dims[0]);
-    int blank = ctx->Attrs().Get<int>("blank");
-    PADDLE_ENFORCE((blank >= 0) && (blank < sequence_width),
-                   "The value of Attr(blank) should be in interval [0, %d).",
-                   sequence_width);
     // TODO(wanghaoshuang): it is tricky to set the wrong dimension here.
-    ctx->SetOutputDim("Output", {input_dims[0], 1});
+    ctx->SetOutputDim("Output", input_dims);
   }
 
  protected:
@@ -53,25 +47,37 @@ class CTCGreedyDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
   CTCGreedyDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Input",
-             "(LodTensor, default: LoDTensor<float>), the unscaled "
-             "probabilities of variable-length sequences, which is a 2-D "
-             "Tensor with LoD information. It's shape is "
-             "[Lp, num_classes + 1], where Lp is the sum of all input "
-             "sequences' length and num_classes is the true number of classes "
-             "(not including the blank label).");
-    AddOutput("Output", "(Tensor, default: Tensor<int>), the decode result ");
+             "(LodTensor, default: LoDTensor<int>), Its shape is "
+             "[Lp, 1], where Lp is the sum of all input sequences' length.");
+    AddOutput("Output", "(Tensor, default: Tensor<int>), The decode result.");
     AddAttr<int>("blank",
                  "(int, default: 0), the blank label setted in Connectionist "
-                 "Temporal Classification (CTC) op, and it is in the "
-                 "half-opened interval [0, num_classes + 1).")
+                 "Temporal Classification (CTC) op.")
         .SetDefault(0);
     AddAttr<bool>("merge_repeated",
                   "(bool, default: true), whether to "
                   "merge repeated elements between two blanks. ")
         .SetDefault(true);
     AddComment(R"DOC(
-CTCGreedyDecoder is an implementation of the simple best path decoding
-algorithm, selecting at each timestep the most likely class at each timestep.
+CTCDecoder is used to merge repeated elements between two blanks
+and then delete all blanks in sequence.
+
+Given:
+    Input.data = [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6,
+                  6, 0, 0, 7, 7, 7, 0]
+    Input.dims = {18, 1}
+    Input.LoD = [[0, 11, 18]]
+
+And:
+    blank = 0
+    merge_repeated = True
+
+Then:
+    Output.data = [1, 2, 4, 4, 5, 6,
+                   6, 7]
+    Output.dims = {8, 1}
+    Output.LoD = [[0, 6, 8]]
+
 )DOC");
   }
 };
@@ -85,4 +91,4 @@ REGISTER_OPERATOR(ctc_greedy_decode, ops::CTCGreedyDecodeOp,
                   paddle::framework::EmptyGradOpMaker);
 REGISTER_OP_CPU_KERNEL(
     ctc_greedy_decode,
-    ops::CTCGreedyDecodeKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::CTCGreedyDecodeKernel<paddle::platform::CPUDeviceContext, int>);
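The Given/Then example in the refined DOC comment can be checked with a small host-side sketch. This is illustrative NumPy only, not part of the operator; the helper name ctc_decode_ref is made up here.

# Minimal sketch of the decode rule documented above (assumes NumPy only);
# `ctc_decode_ref` is a hypothetical helper, not part of the Paddle API.
import numpy as np

def ctc_decode_ref(tokens, lod0, blank=0, merge_repeated=True):
    out, out_lod0 = [], [0]
    for i in range(len(lod0) - 1):
        prev = None
        for t in tokens[lod0[i]:lod0[i + 1]]:
            if t != blank and not (merge_repeated and t == prev):
                out.append(t)
            prev = t
        out_lod0.append(len(out))        # absolute end offset per sequence
    return np.array(out).reshape(-1, 1), [out_lod0]

data = [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]
out, lod = ctc_decode_ref(data, [0, 11, 18])
assert out.flatten().tolist() == [1, 2, 4, 4, 5, 6, 6, 7]   # Output.data
assert lod == [[0, 6, 8]]                                   # Output.LoD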

paddle/operators/ctc_greedy_decode_op.cu renamed to paddle/operators/ctc_decode_op.cu

Lines changed: 15 additions & 62 deletions
@@ -16,62 +16,20 @@ limitations under the License. */
 #include <thrust/device_vector.h>
 #include <thrust/host_vector.h>
 #include "paddle/operators/ctc_greedy_decode_op.h"
-#include "paddle/platform/cuda_helper.h"
-#include "paddle/platform/gpu_info.h"
 
 namespace paddle {
 namespace operators {
-using platform::PADDLE_CUDA_NUM_THREADS;
-
-__device__ static float atomicMaxF(float* address, float val) {
-  int* address_as_i = (int*)address;
-  int old = *address_as_i, assumed;
-  do {
-    assumed = old;
-    old = ::atomicCAS(address_as_i, assumed,
-                      __float_as_int(::fmaxf(val, __int_as_float(assumed))));
-  } while (assumed != old);
-  return __int_as_float(old);
-}
-
-template <typename T, int BlockSize>
-__global__ void ArgmaxCudaKernel(const size_t seq_width, const T* logits,
-                                 int* output) {
-  T local_max_value = 0;
-  int local_max_index = 0;
-  __shared__ T max_value;
-  if (threadIdx.x == 0) {
-    max_value = 0;
-  }
-  __syncthreads();
-
-  for (int i = threadIdx.x; i < seq_width; i += BlockSize) {
-    T value = logits[blockIdx.x * seq_width + i];
-    if (value > local_max_value) {
-      local_max_value = value;
-      local_max_index = i;
-    }
-  }
-
-  atomicMaxF(&max_value, local_max_value);
-
-  __syncthreads();
-
-  if (local_max_value == max_value) {
-    output[blockIdx.x] = local_max_index;
-  }
-}
 
 template <typename T>
-__global__ void MergeAndDelCudaKernel(const int64_t num_token, int* tokens,
+__global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens,
                                       const size_t num_seq, size_t* lod0,
                                       const int blank, const int merge_repeated,
-                                      size_t* out_lod0, int* output) {
+                                      size_t* out_lod0, T* output) {
   int ouput_idx = 0;
   out_lod0[0] = 0;
 
   for (int i = 0; i < num_seq; ++i) {
-    int pre_token = -1;
+    T pre_token = -1;
     for (int j = lod0[i]; j < lod0[i + 1]; ++j) {
       if (tokens[j] != blank && !(merge_repeated && tokens[j] == pre_token)) {
         output[ouput_idx] = tokens[j];
@@ -89,44 +47,39 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "It must use CUDAPlace.");
+    const size_t level = 0;
     auto* input = ctx.Input<LoDTensor>("Input");
     auto* output = ctx.Output<LoDTensor>("Output");
+    auto input_lod = framework::ToAbsOffset(input->lod());
 
+    const T* tokens = input->data<T>();
     const int64_t num_tokens = input->dims()[0];
-    const size_t seq_width = input->numel() / num_tokens;
-    const T* logits = input->data<T>();
-    Tensor tmp;
-    int* tokens = tmp.mutable_data<int>({num_tokens, 1}, ctx.GetPlace());
-    // get argmax
-    // platform::GpuMemsetAsync(args, 0, sizeof(float), stream);
-
-    auto stream = ctx.cuda_device_context().stream();
-    ArgmaxCudaKernel<T, PADDLE_CUDA_NUM_THREADS><<<
-        num_tokens, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(seq_width, logits,
-                                                          tokens);
-
-    const size_t level = 0;
-    auto input_lod = framework::ToAbsOffset(input->lod());
     const size_t num_seq = input_lod[level].size() - 1;
+
     const int blank = ctx.Attr<int>("blank");
     const int merge_repeated =
        static_cast<int>(ctx.Attr<bool>("merge_repeated"));
 
+    // prepare a lod to record lod information while merging elements
     thrust::device_vector<size_t> dev_out_lod0(input_lod[level].size());
     size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data());
 
-    int* output_data =
-        output->mutable_data<int>({num_tokens, 1}, ctx.GetPlace());
+    // merge elements and delete blank
+    T* output_data = output->mutable_data<T>({num_tokens, 1}, ctx.GetPlace());
+
+    auto stream = ctx.cuda_device_context().stream();
     MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
        num_tokens, tokens, num_seq, input_lod[level].data(), blank,
        merge_repeated, dev_out_lod0_ptr, output_data);
 
+    // set output lod
     thrust::host_vector<size_t> host_out_lod0(dev_out_lod0.begin(),
                                               dev_out_lod0.end());
     framework::LoD out_lod;
     out_lod.push_back(host_out_lod0);
     output->set_lod(out_lod);
 
+    // resize output dims
     output->Resize({static_cast<int64_t>(host_out_lod0.back()), 1});
   }
 };
@@ -135,4 +88,4 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 REGISTER_OP_CUDA_KERNEL(ctc_greedy_decode,
-                        paddle::operators::CTCGreedyDecodeOpCUDAKernel<float>);
+                        paddle::operators::CTCGreedyDecodeOpCUDAKernel<int>);
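MergeAndDelCudaKernel is launched on a single thread and records a running end offset per sequence into out_lod0; the host then copies those offsets back, installs them as the output LoD, and shrinks the buffer that was pre-allocated with num_tokens rows down to the last offset. A rough NumPy stand-in for that allocate-then-resize bookkeeping (the kept-token counts below are taken from the DOC example; this is a sketch, not the CUDA code):

# Sketch of the host-side bookkeeping around the kernel launch (NumPy only).
import numpy as np

num_tokens = 18                                  # rows of the input LoDTensor
kept_per_sequence = [6, 2]                       # survivors of merge + blank removal
out_lod0 = np.cumsum([0] + kept_per_sequence)    # -> [0, 6, 8], the new LoD level

output = np.empty((num_tokens, 1), dtype=np.int32)  # mutable_data({num_tokens, 1})
# ... the kernel writes the kept tokens into output[0:out_lod0[-1]] ...
output = output[:out_lod0[-1]]                   # output->Resize({lod.back(), 1})
print(out_lod0.tolist(), output.shape)           # [0, 6, 8] (8, 1)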

paddle/operators/ctc_greedy_decode_op.h renamed to paddle/operators/ctc_decode_op.h

Lines changed: 16 additions & 18 deletions
@@ -16,7 +16,6 @@ limitations under the License. */
 
 #include <string.h>
 #include "paddle/framework/op_registry.h"
-#include "unsupported/Eigen/CXX11/Tensor"
 namespace paddle {
 namespace operators {
 
@@ -30,47 +29,46 @@ class CTCGreedyDecodeKernel : public framework::OpKernel<T> {
     auto* input = ctx.Input<LoDTensor>("Input");
     auto* output = ctx.Output<LoDTensor>("Output");
     const size_t level = 0;
-
     auto input_lod = framework::ToAbsOffset(input->lod());
+
+    // check input dims and lod
     auto input_dims = input->dims();
     PADDLE_ENFORCE_EQ(input_dims[0],
                       static_cast<int64_t>(input_lod[level].back()),
                       "The first dimension of Input(Input) should be equal to "
                       "the sum of all sequences' lengths.");
 
     const size_t num_sequences = input_lod[level].size() - 1;
-    const size_t sequence_width = input->numel() / input_dims[0];
     size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));
     bool merge_repeated = ctx.Attr<bool>("merge_repeated");
+
+    // merge repeated tokens and delete blank
     std::vector<std::vector<int>> pathes(num_sequences);
     std::vector<size_t> output_lod0(1, 0);
-
     const T* input_data = input->data<T>();
-    Eigen::Map<
-        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
-        input_mat(const_cast<T*>(input_data), input->numel() / sequence_width,
-                  sequence_width);
-
-    size_t max_class_idx;
-    size_t prev_class_idx = -1;
     for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) {
+      T prev_token = -1;
       for (size_t i = input_lod[level][seq_idx];
            i < input_lod[level][seq_idx + 1]; ++i) {
-        input_mat.row(i).maxCoeff(&max_class_idx);
-        if (max_class_idx != blank &&
-            !(merge_repeated && max_class_idx == prev_class_idx)) {
-          pathes[seq_idx].push_back(max_class_idx);
+        if (input_data[i] != blank &&
+            !(merge_repeated && input_data[i] == prev_token)) {
+          pathes[seq_idx].push_back(input_data[i]);
         }
-        prev_class_idx = max_class_idx;
+        prev_token = input_data[i];
       }
       output_lod0.push_back(output_lod0.back() + pathes[seq_idx].size());
     }
+
+    // set output lod
     framework::LoD output_lod;
     output_lod.push_back(output_lod0);
     output->set_lod(output_lod);
-    int64_t num_step = static_cast<int64_t>(output_lod0.back());
-    int* output_data = output->mutable_data<int>({num_step, 1}, ctx.GetPlace());
 
+    // resize output dims
+    T* output_data = output->mutable_data<T>(
+        {static_cast<int64_t>(output_lod0.back()), 1}, ctx.GetPlace());
+
+    // copy result to output
     for (int i = 0; i < num_sequences; ++i) {
       memcpy(output_data + output_lod0[i], pathes[i].data(),
              sizeof(int) * pathes[i].size());
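Note that the rewritten loop re-initializes prev_token to -1 at the start of each sequence, so merge_repeated never collapses an identical token across a sequence boundary. A toy sketch (made-up data, not from the commit) of why that reset matters:

# Toy illustration: without the per-sequence reset, the first token of the
# second sequence would be merged into the last token of the first one.
def decode(tokens, lod0, blank=0, merge_repeated=True, reset_per_seq=True):
    out, prev = [], -1
    for i in range(len(lod0) - 1):
        if reset_per_seq:
            prev = -1                    # mirrors `T prev_token = -1;` per sequence
        for t in tokens[lod0[i]:lod0[i + 1]]:
            if t != blank and not (merge_repeated and t == prev):
                out.append(t)
            prev = t
    return out

tokens = [0, 2, 2, 2, 3]                 # two sequences: [0, 2, 2] and [2, 3]
lod0 = [0, 3, 5]
print(decode(tokens, lod0))                       # [2, 2, 3]: second 2 is kept
print(decode(tokens, lod0, reset_per_seq=False))  # [2, 3]: 2 merged across sequences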
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+import sys
+import unittest
+import numpy as np
+from op_test import OpTest
+from test_softmax_op import stable_softmax
+
+
+def CTCDecode(input, lod, blank, merge_repeated):
+    lod0 = lod[0]
+    result = []
+    for i in range(len(lod0) - 1):
+        prev_token = -1
+        for j in range(lod0[i], lod0[i + 1]):
+            token = input[j][0]
+            if (token != blank) and not (merge_repeated and
+                                         token == prev_token):
+                result.append(token)
+            prev_token = token
+    result = np.array(result).reshape([len(result), 1]).astype("int32")
+    return result
+
+
+class TestCTCDecodeOp(OpTest):
+    def config(self):
+        self.op_type = "ctc_greedy_decode"
+        self.input_lod = [[0, 11, 18]]
+        self.blank = 0
+        self.merge_repeated = False
+        self.input = np.array(
+            [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]).reshape(
+                [18, 1]).astype("int32")
+
+    def setUp(self):
+        self.config()
+        output = CTCDecode(self.input, self.input_lod, self.blank,
+                           self.merge_repeated)
+
+        self.inputs = {"Input": (self.input, self.input_lod), }
+        self.outputs = {"Output": output}
+        self.attrs = {
+            "blank": self.blank,
+            "merge_repeated": self.merge_repeated
+        }
+
+    def test_check_output(self):
+        self.check_output()
+        pass
+
+
+class TestCTCDecodeOpCase1(TestCTCDecodeOp):
+    def config(self):
+        self.op_type = "ctc_greedy_decode"
+        self.input_lod = [[0, 11, 18]]
+        self.blank = 0
+        self.merge_repeated = True
+        self.input = np.array(
+            [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]).reshape(
+                [18, 1]).astype("int32")
+
+
+if __name__ == "__main__":
+    unittest.main()
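For contrast with the DOC example, the base test case sets merge_repeated=False, so only blanks are stripped and repeated tokens survive. A standalone restatement of what the CTCDecode reference helper returns for that setting (illustrative only, not extra test code from the commit):

# Expected decode for the merge_repeated=False case: drop blanks only.
import numpy as np

data = np.array([0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6,
                 6, 0, 0, 7, 7, 7, 0]).reshape([18, 1]).astype("int32")
expected = [int(t) for t in data.flatten() if t != 0]
print(expected)   # [1, 2, 2, 4, 4, 5, 6, 6, 7, 7, 7]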

python/paddle/v2/fluid/tests/test_ctc_greedy_decode.py

Lines changed: 0 additions & 56 deletions
This file was deleted.
