Skip to content

Commit 861b84f

Browse files
author
Yibing Liu
authored
Merge pull request #5300 from kuke/ctc_edit_distance_dev
Add edit distance operator
2 parents 377424b + fe0ef91 commit 861b84f

File tree

4 files changed

+437
-0
lines changed

4 files changed

+437
-0
lines changed

paddle/operators/edit_distance_op.cc

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/edit_distance_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class EditDistanceOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
void InferShape(framework::InferShapeContext *ctx) const override {
25+
PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null.");
26+
PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null.");
27+
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
28+
auto hyp_dims = ctx->GetInputDim("Hyps");
29+
auto ref_dims = ctx->GetInputDim("Refs");
30+
PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1,
31+
"Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
32+
"equal to 1.");
33+
PADDLE_ENFORCE(ref_dims.size() == 2 && ref_dims[1] == 1,
34+
"Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
35+
"equal to 1.");
36+
ctx->SetOutputDim("Out", ctx->GetInputDim("Refs"));
37+
}
38+
39+
protected:
40+
framework::OpKernelType GetExpectedKernelType(
41+
const framework::ExecutionContext &ctx) const override {
42+
return framework::OpKernelType(framework::proto::DataType::FP32,
43+
ctx.device_context());
44+
}
45+
};
46+
47+
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
48+
public:
49+
EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
50+
: OpProtoAndCheckerMaker(proto, op_checker) {
51+
AddInput("Hyps",
52+
"(2-D LoDTensor<int>, 2nd dim. equal to 1) "
53+
"The indices for hypothesis strings.");
54+
AddInput("Refs",
55+
"(2-D LoDTensor<int>, 2nd dim. equal to 1) "
56+
"The indices for reference strings.");
57+
AddAttr<bool>("normalized",
58+
"(bool, default false) Indicated whether to normalize "
59+
"the edit distance by the length of reference string.")
60+
.SetDefault(false);
61+
AddOutput("Out",
62+
"(2-D Tensor with shape [`batch_size` x 1]) "
63+
"The output edit distances of EditDistance operator.");
64+
AddComment(R"DOC(
65+
66+
EditDistance operator computes the edit distances between a batch of hypothesis
67+
strings and their references.
68+
69+
Edit distance, also called Levenshtein distance, measures how dissimilar two strings
70+
are by counting the minimum number of operations to transform one string into anthor.
71+
Here the operations include insertion, deletion, and substitution. For example,
72+
given hypothesis string A = "kitten" and reference B = "sitting", the edit distance
73+
is 3 for A will be transformed into B at least after two substitutions and one
74+
insertion:
75+
76+
"kitten" -> "sitten" -> "sittin" -> "sitting"
77+
78+
Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total
79+
number denoted by `batch_size`, and the separation is specified by the LoD information.
80+
And the `batch_size` reference strings are arranged in order in the same way in the
81+
LoDTensor Input(Refs).
82+
83+
Output(Out) contains the `batch_size` results and each stands for the edit stance
84+
for a pair of strings respectively. If Attr(normalized) is true, the edit distance
85+
will be divided by the length of reference string.
86+
)DOC");
87+
}
88+
};
89+
90+
} // namespace operators
91+
} // namespace paddle
92+
93+
namespace ops = paddle::operators;
94+
95+
REGISTER_OPERATOR(edit_distance, ops::EditDistanceOp, ops::EditDistanceOpMaker,
96+
paddle::framework::EmptyGradOpMaker);
97+
REGISTER_OP_CPU_KERNEL(
98+
edit_distance, ops::EditDistanceKernel<paddle::platform::CPUPlace, float>);

paddle/operators/edit_distance_op.cu

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <algorithm>
16+
#include "paddle/framework/op_registry.h"
17+
#include "paddle/platform/cuda_helper.h"
18+
#include "paddle/platform/gpu_info.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
using platform::PADDLE_CUDA_NUM_THREADS;
24+
25+
template <typename T>
26+
__global__ void FillFirstRow(T* dist, const int N) {
27+
int idx = blockDim.x * blockIdx.x + threadIdx.x;
28+
if (idx < N + 1) {
29+
dist[idx] = idx;
30+
}
31+
}
32+
33+
template <typename T>
34+
__global__ void FillFirstColumn(T* dist, const int M, const int N) {
35+
int idx = blockDim.x * blockIdx.x + threadIdx.x;
36+
if (idx < M + 1) {
37+
dist[idx * (N + 1)] = idx;
38+
}
39+
}
40+
41+
template <typename T>
42+
__global__ void Levenshtein(T* dist, const int* x1, const int* x2, const int M,
43+
const int N, const int start) {
44+
int idx = blockDim.x * blockIdx.x + threadIdx.x;
45+
int offset = N;
46+
int index = start + idx * offset;
47+
int row = index / (N + 1);
48+
int col = index % (N + 1);
49+
if (row > 0 && col > 0 && row < M + 1 && col < N + 1) {
50+
int cost = x1[row - 1] == x2[col - 1] ? 0 : 1;
51+
int dels = dist[(row - 1) * (N + 1) + col] + 1;
52+
int ins = dist[row * (N + 1) + col - 1] + 1;
53+
int subs = dist[(row - 1) * (N + 1) + (col - 1)] + cost;
54+
dist[index] = min(dels, min(ins, subs));
55+
}
56+
}
57+
58+
template <typename T>
59+
__global__ void SetOutput(T* out, const T* dist, const int M, const int N,
60+
bool normalized) {
61+
int idx = blockDim.x * blockIdx.x + threadIdx.x;
62+
if (idx == 0) {
63+
out[0] = normalized ? dist[M * (N + 1) + N] / N : dist[M * (N + 1) + N];
64+
}
65+
}
66+
67+
template <typename Place, typename T>
68+
class EditDistanceGPUKernel : public framework::OpKernel<T> {
69+
public:
70+
void Compute(const framework::ExecutionContext& ctx) const {
71+
auto* out_t = ctx.Output<framework::Tensor>("Out");
72+
73+
auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps");
74+
auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
75+
76+
auto normalized = ctx.Attr<bool>("normalized");
77+
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
78+
ctx.device_context())
79+
.stream();
80+
81+
auto hyp_lod = x1_t->lod()[0];
82+
auto ref_lod = x2_t->lod()[0];
83+
PADDLE_ENFORCE(
84+
hyp_lod.size() == ref_lod.size(),
85+
"Input(Hyps) and Input(Refs) must have the same batch size.");
86+
for (size_t i = 1; i < ref_lod.size(); ++i) {
87+
PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
88+
"Reference string %d is empty.", i);
89+
}
90+
91+
auto num_strs = hyp_lod.size() - 1;
92+
out_t->Resize({static_cast<int64_t>(num_strs), 1});
93+
out_t->mutable_data<T>(ctx.GetPlace());
94+
auto out = out_t->data<T>();
95+
96+
T distance = 0.0;
97+
for (size_t num = 0; num < num_strs; num++) {
98+
auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]);
99+
auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]);
100+
if (m == 0 || n == 0) {
101+
distance = std::max(m, n);
102+
if (normalized) {
103+
PADDLE_ENFORCE(n > 0,
104+
"The reference string (#%d) cannot be empty "
105+
"when Attr(normalized) is enabled.",
106+
n);
107+
distance = distance / n;
108+
}
109+
memory::Copy(boost::get<Place>(ctx.GetPlace()), out + num,
110+
platform::CPUPlace(), &distance, sizeof(T), stream);
111+
} else {
112+
framework::Tensor dist_t;
113+
dist_t.Resize({m + 1, n + 1});
114+
dist_t.mutable_data<T>(ctx.GetPlace());
115+
auto dist = dist_t.data<T>();
116+
auto x1 = x1_t->data<int>() + hyp_lod[num];
117+
auto x2 = x2_t->data<int>() + ref_lod[num];
118+
119+
FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS,
120+
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n);
121+
122+
FillFirstRow<T><<<1 + n / PADDLE_CUDA_NUM_THREADS,
123+
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n);
124+
// Compute the elements of distance matrix in the anti-diagonal diretion
125+
for (int64_t slice = 2; slice < m + n + 1; ++slice) {
126+
int z_m = slice < m + 1 ? 0 : slice - m;
127+
int z_n = slice < n + 1 ? 0 : slice - n;
128+
int size = slice - (z_m + z_n) + 1; // number of elments in the same
129+
// anti-diagonal line to update
130+
// the start index at which computes from
131+
int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1;
132+
Levenshtein<T><<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS,
133+
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2,
134+
m, n, start);
135+
}
136+
SetOutput<T><<<1, 1, 0, stream>>>(out + num, dist, m, n, normalized);
137+
}
138+
}
139+
}
140+
};
141+
142+
} // namespace operators
143+
} // namespace paddle
144+
145+
namespace ops = paddle::operators;
146+
147+
REGISTER_OP_CUDA_KERNEL(
148+
edit_distance,
149+
ops::EditDistanceGPUKernel<paddle::platform::CUDAPlace, float>);

paddle/operators/edit_distance_op.h

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <algorithm>
17+
#include "paddle/framework/eigen.h"
18+
#include "paddle/framework/op_registry.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
template <typename Place, typename T>
24+
class EditDistanceKernel : public framework::OpKernel<T> {
25+
public:
26+
void Compute(const framework::ExecutionContext& ctx) const {
27+
auto* out_t = ctx.Output<framework::Tensor>("Out");
28+
29+
auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps");
30+
auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
31+
32+
auto normalized = ctx.Attr<bool>("normalized");
33+
34+
auto hyp_lod = x1_t->lod()[0];
35+
auto ref_lod = x2_t->lod()[0];
36+
PADDLE_ENFORCE(
37+
hyp_lod.size() == ref_lod.size(),
38+
"Input(Hyps) and Input(Refs) must have the same batch size.");
39+
for (size_t i = 1; i < ref_lod.size(); ++i) {
40+
PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1],
41+
"Reference string %d is empty.", i);
42+
}
43+
auto num_strs = hyp_lod.size() - 1;
44+
45+
out_t->Resize({static_cast<int64_t>(num_strs), 1});
46+
out_t->mutable_data<float>(ctx.GetPlace());
47+
auto out = out_t->data<T>();
48+
49+
T distance = 0.0;
50+
for (size_t num = 0; num < num_strs; ++num) {
51+
auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]);
52+
auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]);
53+
54+
if (m == 0) {
55+
distance = n;
56+
} else if (n == 0) {
57+
distance = m;
58+
} else {
59+
framework::Tensor dist_t;
60+
dist_t.Resize({m + 1, n + 1});
61+
dist_t.mutable_data<T>(ctx.GetPlace());
62+
auto dist = dist_t.data<T>();
63+
auto x1 = x1_t->data<int>() + hyp_lod[num];
64+
auto x2 = x2_t->data<int>() + ref_lod[num];
65+
for (int64_t i = 0; i < m + 1; ++i) {
66+
dist[i * (n + 1)] = i;
67+
}
68+
for (int64_t j = 0; j < n + 1; ++j) {
69+
dist[j] = j;
70+
}
71+
for (int64_t i = 1; i < m + 1; ++i) {
72+
for (int64_t j = 1; j < n + 1; ++j) {
73+
int cost = x1[i - 1] == x2[j - 1] ? 0 : 1;
74+
int dels = dist[(i - 1) * (n + 1) + j] + 1;
75+
int ins = dist[i * (n + 1) + (j - 1)] + 1;
76+
int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost;
77+
dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs));
78+
}
79+
}
80+
distance = dist[m * (n + 1) + n];
81+
}
82+
83+
if (normalized) {
84+
PADDLE_ENFORCE(n > 0,
85+
"The reference string (#%d) cannot be empty "
86+
"when Attr(normalized) is enabled.",
87+
n);
88+
distance = distance / n;
89+
}
90+
out[num] = distance;
91+
}
92+
}
93+
};
94+
95+
} // namespace operators
96+
} // namespace paddle

0 commit comments

Comments
 (0)